From efce1b04e0113805a64e0e59f9fd65b9eca02273 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 17 Sep 2025 22:34:11 +0800 Subject: [PATCH 01/13] fix style check --- api/commands.py | 38 ++--- .../service_api/dataset/document.py | 15 +- .../rag_pipeline/rag_pipeline_workflow.py | 5 +- .../pipeline/generate_response_converter.py | 2 +- api/core/datasource/datasource_manager.py | 2 +- api/core/datasource/entities/api_entities.py | 2 +- .../datasource/utils/message_transformer.py | 53 +++---- api/core/workflow/entities/variable_pool.py | 2 +- .../nodes/datasource/datasource_node.py | 61 ++------ .../entity/pipeline_service_api_entities.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 135 +++++++++++++++--- api/services/tools/tools_transform_service.py | 6 - 12 files changed, 196 insertions(+), 127 deletions(-) diff --git a/api/commands.py b/api/commands.py index 3ff0d1fbe1..44199f0ff8 100644 --- a/api/commands.py +++ b/api/commands.py @@ -1440,12 +1440,12 @@ def transform_datasource_credentials(): notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all() if notion_credentials: notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} - for credential in notion_credentials: - tenant_id = credential.tenant_id + for notion_credential in notion_credentials: + tenant_id = notion_credential.tenant_id if tenant_id not in notion_credentials_tenant_mapping: notion_credentials_tenant_mapping[tenant_id] = [] - notion_credentials_tenant_mapping[tenant_id].append(credential) - for tenant_id, credentials in notion_credentials_tenant_mapping.items(): + notion_credentials_tenant_mapping[tenant_id].append(notion_credential) + for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): # check notion plugin is installed installed_plugins = installer_manager.list_plugins(tenant_id) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] @@ -1454,12 +1454,12 @@ def transform_datasource_credentials(): # install notion plugin PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) auth_count = 0 - for credential in credentials: + for notion_tenant_credential in notion_tenant_credentials: auth_count += 1 # get credential oauth params - access_token = credential.access_token + access_token = notion_tenant_credential.access_token # notion info - notion_info = credential.source_info + notion_info = notion_tenant_credential.source_info workspace_id = notion_info.get("workspace_id") workspace_name = notion_info.get("workspace_name") workspace_icon = notion_info.get("workspace_icon") @@ -1487,12 +1487,12 @@ def transform_datasource_credentials(): firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all() if firecrawl_credentials: firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} - for credential in firecrawl_credentials: - tenant_id = credential.tenant_id + for firecrawl_credential in firecrawl_credentials: + tenant_id = firecrawl_credential.tenant_id if tenant_id not in firecrawl_credentials_tenant_mapping: firecrawl_credentials_tenant_mapping[tenant_id] = [] - firecrawl_credentials_tenant_mapping[tenant_id].append(credential) - for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items(): + firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) + for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): # 
check firecrawl plugin is installed installed_plugins = installer_manager.list_plugins(tenant_id) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] @@ -1502,10 +1502,10 @@ def transform_datasource_credentials(): PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) auth_count = 0 - for credential in credentials: + for firecrawl_tenant_credential in firecrawl_tenant_credentials: auth_count += 1 # get credential api key - credentials_json = json.loads(credential.credentials) + credentials_json = json.loads(firecrawl_tenant_credential.credentials) api_key = credentials_json.get("config", {}).get("api_key") base_url = credentials_json.get("config", {}).get("base_url") new_credentials = { @@ -1530,12 +1530,12 @@ def transform_datasource_credentials(): jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all() if jina_credentials: jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} - for credential in jina_credentials: - tenant_id = credential.tenant_id + for jina_credential in jina_credentials: + tenant_id = jina_credential.tenant_id if tenant_id not in jina_credentials_tenant_mapping: jina_credentials_tenant_mapping[tenant_id] = [] - jina_credentials_tenant_mapping[tenant_id].append(credential) - for tenant_id, credentials in jina_credentials_tenant_mapping.items(): + jina_credentials_tenant_mapping[tenant_id].append(jina_credential) + for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): # check jina plugin is installed installed_plugins = installer_manager.list_plugins(tenant_id) installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] @@ -1546,10 +1546,10 @@ def transform_datasource_credentials(): PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) auth_count = 0 - for credential in credentials: + for jina_tenant_credential in jina_tenant_credentials: auth_count += 1 # get credential api key - credentials_json = json.loads(credential.credentials) + credentials_json = json.loads(jina_tenant_credential.credentials) api_key = credentials_json.get("config", {}).get("api_key") new_credentials = { "integration_secret": api_key, diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index a9f7608733..d26c64fe36 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -124,6 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) + if not current_user: + raise ValueError("current_user is required") + upload_file = FileService(db.engine).upload_text( text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id ) @@ -204,6 +207,8 @@ class DocumentUpdateByTextApi(DatasetApiResource): name = args.get("name") if text is None or name is None: raise ValueError("Both text and name must be strings.") + if not current_user: + raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_text( text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id ) @@ -308,6 +313,8 @@ class DocumentAddByFileApi(DatasetApiResource): if not isinstance(current_user, EndUser): raise ValueError("Invalid user account") + if not current_user: + raise ValueError("current_user is required") upload_file = FileService(db.engine).upload_file( 
             filename=file.filename,
             content=file.read(),
@@ -396,8 +403,12 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError
 
+        if not current_user:
+            raise ValueError("current_user is required")
+
         if not isinstance(current_user, EndUser):
             raise ValueError("Invalid user account")
+
         try:
             upload_file = FileService(db.engine).upload_file(
                 filename=file.filename,
@@ -577,7 +588,7 @@ class DocumentApi(DatasetApiResource):
             response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
         elif metadata == "without":
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
@@ -610,7 +621,7 @@ class DocumentApi(DatasetApiResource):
             }
         else:
             dataset_process_rules = DatasetService.get_process_rules(dataset_id)
-            document_process_rules = document.dataset_process_rule.to_dict()
+            document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
             data_source_info = document.data_source_detail_dict
             response = {
                 "id": document.id,
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index 55bfdde009..ad578d947e 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -214,7 +214,10 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
             raise UnsupportedFileTypeError()
 
         if not file.filename:
-            raise FilenameNotExistsError
+            raise FilenameNotExistsError
+
+        if not current_user:
+            raise ValueError("Invalid user account")
 
         try:
             upload_file = FileService(db.engine).upload_file(
diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py
index e125958180..f47db16c18 100644
--- a/api/core/app/apps/pipeline/generate_response_converter.py
+++ b/api/core/app/apps/pipeline/generate_response_converter.py
@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         :param blocking_response: blocking response
         :return:
         """
-        return dict(blocking_response.to_dict())
+        return dict(blocking_response.model_dump())
 
     @classmethod
     def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict:  # type: ignore[override]
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index 47f314e126..3144712fe9 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -46,7 +46,7 @@ class DatasourceManager:
         provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
         if not provider_entity:
             raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
-
+        controller = None
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 controller = OnlineDocumentDatasourcePluginProviderController(
diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py
index 81771719ea..416ab90ff1 100644
--- a/api/core/datasource/entities/api_entities.py
+++ b/api/core/datasource/entities/api_entities.py
@@ -62,7 +62,7 @@ class
DatasourceProviderApiEntity(BaseModel): "description": self.description.to_dict(), "icon": self.icon, "label": self.label.to_dict(), - "type": self.type.value, + "type": self.type, "team_credentials": self.masked_credentials, "is_team_authorization": self.is_team_authorization, "allow_delete": self.allow_delete, diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index bb6ac6c1fc..39a294e625 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -6,6 +6,7 @@ from typing import Optional from core.datasource.entities.datasource_entities import DatasourceMessage from core.file import File, FileTransferMethod, FileType from core.tools.tool_file_manager import ToolFileManager +from models.tools import ToolFile logger = logging.getLogger(__name__) @@ -32,20 +33,20 @@ class DatasourceFileMessageTransformer: try: assert isinstance(message.message, DatasourceMessage.TextMessage) tool_file_manager = ToolFileManager() - file = tool_file_manager.create_file_by_url( + tool_file: ToolFile | None = tool_file_manager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, conversation_id=conversation_id, ) + if tool_file: + url = f"/files/datasources/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" - url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}" - - yield DatasourceMessage( - type=DatasourceMessage.MessageType.IMAGE_LINK, - message=DatasourceMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, - ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=message.meta.copy() if message.meta is not None else {}, + ) except Exception as e: yield DatasourceMessage( type=DatasourceMessage.MessageType.TEXT, @@ -72,7 +73,7 @@ class DatasourceFileMessageTransformer: # FIXME: should do a type check here. 
assert isinstance(message.message.blob, bytes) tool_file_manager = ToolFileManager() - file = tool_file_manager.create_file_by_raw( + blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, @@ -80,25 +81,27 @@ class DatasourceFileMessageTransformer: mimetype=mimetype, filename=filename, ) - - url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype)) - - # check if file is image - if "image" in mimetype: - yield DatasourceMessage( - type=DatasourceMessage.MessageType.IMAGE_LINK, - message=DatasourceMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, - ) - else: - yield DatasourceMessage( - type=DatasourceMessage.MessageType.BINARY_LINK, - message=DatasourceMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + if blob_tool_file: + url = cls.get_datasource_file_url( + datasource_file_id=blob_tool_file.id, extension=guess_extension(blob_tool_file.mimetype) ) + + # check if file is image + if "image" in mimetype: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) + else: + yield DatasourceMessage( + type=DatasourceMessage.MessageType.BINARY_LINK, + message=DatasourceMessage.TextMessage(text=url), + meta=meta.copy() if meta is not None else {}, + ) elif message.type == DatasourceMessage.MessageType.FILE: meta = message.meta or {} - file = meta.get("file", None) + file: Optional[File] = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: assert file.related_id is not None diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 845ecbc125..8ceabde7e6 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -67,7 +67,7 @@ class VariablePool(BaseModel): self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) # Add rag pipeline variables to the variable pool if self.rag_pipeline_variables: - rag_pipeline_variables_map = defaultdict(dict) + rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) for rag_var in self.rag_pipeline_variables: node_id = rag_var.variable.belong_to_node_id key = rag_var.variable.variable diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index eb58ba14c1..b37fd4e6be 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -19,7 +19,7 @@ from core.file.enums import FileTransferMethod, FileType from core.plugin.impl.exc import PluginDaemonClientSideError from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable -from core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent @@ -87,29 +87,18 @@ class DatasourceNode(Node): raise DatasourceNodeError("Invalid datasource info format") datasource_info: dict[str, Any] = datasource_info_value # get datasource 
runtime - try: - from core.datasource.datasource_manager import DatasourceManager + from core.datasource.datasource_manager import DatasourceManager - if datasource_type is None: - raise DatasourceNodeError("Datasource type is not set") + if datasource_type is None: + raise DatasourceNodeError("Datasource type is not set") - datasource_runtime = DatasourceManager.get_datasource_runtime( - provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", - datasource_name=node_data.datasource_name or "", - tenant_id=self.tenant_id, - datasource_type=DatasourceProviderType.value_of(datasource_type), - ) - datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) - except DatasourceNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info}, - error=f"Failed to get datasource runtime: {str(e)}", - error_type=type(e).__name__, - ) - ) + datasource_runtime = DatasourceManager.get_datasource_runtime( + provider_id=f"{node_data.plugin_id}/{node_data.provider_name}", + datasource_name=node_data.datasource_name or "", + tenant_id=self.tenant_id, + datasource_type=DatasourceProviderType.value_of(datasource_type), + ) + datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id) parameters_for_log = datasource_info @@ -282,27 +271,6 @@ class DatasourceNode(Node): assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) return list(variable.value) if variable else [] - def _append_variables_recursively( - self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue - ): - """ - Append variables recursively - :param node_id: node id - :param variable_key_list: variable key list - :param variable_value: variable value - :return: - """ - variable_pool.add([node_id] + [".".join(variable_key_list)], variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value - ) - @classmethod def _extract_variable_selector_to_variable_mapping( cls, @@ -423,13 +391,6 @@ class DatasourceNode(Node): ) elif message.type == DatasourceMessage.MessageType.JSON: assert isinstance(message.message, DatasourceMessage.JsonMessage) - if self.node_type == NodeType.AGENT: - msg_metadata = message.message.json_object.pop("execution_metadata", {}) - agent_execution_metadata = { - key: value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } json.append(message.message.json_object) elif message.type == DatasourceMessage.MessageType.LINK: assert isinstance(message.message, DatasourceMessage.TextMessage) diff --git a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py index 35005fad71..be321574df 100644 --- a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py +++ b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py @@ -7,7 +7,7 @@ from pydantic import BaseModel class DatasourceNodeRunApiEntity(BaseModel): pipeline_id: str node_id: str - inputs: Mapping[str, Any] + inputs: dict[str, Any] datasource_type: str credential_id: 
Optional[str] = None is_published: bool diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index e27d78b980..e7c255a86a 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -580,10 +580,10 @@ class RagPipelineService: ) yield start_event.model_dump() try: - for message in online_document_result: + for online_document_message in online_document_result: end_time = time.time() online_document_event = DatasourceCompletedEvent( - data=message.result, time_consuming=round(end_time - start_time, 2) + data=online_document_message.result, time_consuming=round(end_time - start_time, 2) ) yield online_document_event.model_dump() except Exception as e: @@ -609,10 +609,10 @@ class RagPipelineService: completed=0, ) yield start_event.model_dump() - for message in online_drive_result: + for online_drive_message in online_drive_result: end_time = time.time() online_drive_event = DatasourceCompletedEvent( - data=message.result, + data=online_drive_message.result, time_consuming=round(end_time - start_time, 2), total=None, completed=None, @@ -629,19 +629,19 @@ class RagPipelineService: ) start_time = time.time() try: - for message in website_crawl_result: + for website_crawl_message in website_crawl_result: end_time = time.time() - if message.result.status == "completed": + if website_crawl_message.result.status == "completed": crawl_event = DatasourceCompletedEvent( - data=message.result.web_info_list or [], - total=message.result.total, - completed=message.result.completed, + data=website_crawl_message.result.web_info_list or [], + total=website_crawl_message.result.total, + completed=website_crawl_message.result.completed, time_consuming=round(end_time - start_time, 2), ) else: crawl_event = DatasourceProcessingEvent( - total=message.result.total, - completed=message.result.completed, + total=website_crawl_message.result.total, + completed=website_crawl_message.result.completed, ) yield crawl_event.model_dump() except Exception as e: @@ -723,12 +723,12 @@ class RagPipelineService: ) try: variables: dict[str, Any] = {} - for message in online_document_result: - if message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: + for online_document_message in online_document_result: + if online_document_message.type == DatasourceMessage.MessageType.VARIABLE: + assert isinstance(online_document_message.message, DatasourceMessage.VariableMessage) + variable_name = online_document_message.message.variable_name + variable_value = online_document_message.message.variable_value + if online_document_message.message.stream: if not isinstance(variable_value, str): raise ValueError("When 'stream' is True, 'variable_value' must be a string.") if variable_name not in variables: @@ -793,8 +793,9 @@ class RagPipelineService: for event in generator: if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)): node_run_result = event.node_run_result - # sign output files - node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {} + if node_run_result: + # sign output files + node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {} break if not node_run_result: @@ -1358,3 +1359,99 @@ class RagPipelineService: workflow_thread_pool_id=None, is_retry=True, ) + + 
def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]: + """ + Get datasource plugins + """ + dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + + workflow: Workflow | None = None + if is_published: + workflow = self.get_published_workflow(pipeline=pipeline) + else: + workflow = self.get_draft_workflow(pipeline=pipeline) + if not pipeline or not workflow: + raise ValueError("Pipeline or workflow not found") + + datasource_nodes = workflow.graph_dict.get("nodes", []) + datasource_plugins = [] + for datasource_node in datasource_nodes: + if datasource_node.get("type") == "datasource": + datasource_node_data = datasource_node.get("data", {}) + if not datasource_node_data: + continue + + variables = workflow.rag_pipeline_variables + if variables: + variables_map = {item["variable"]: item for item in variables} + else: + variables_map = {} + + datasource_parameters = datasource_node_data.get("datasource_parameters", {}) + user_input_variables_keys = [] + user_input_variables = [] + + for _, value in datasource_parameters.items(): + if value.get("value") and isinstance(value.get("value"), str): + pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}" + match = re.match(pattern, value["value"]) + if match: + full_path = match.group(1) + last_part = full_path.split(".")[-1] + user_input_variables_keys.append(last_part) + elif value.get("value") and isinstance(value.get("value"), list): + last_part = value.get("value")[-1] + user_input_variables_keys.append(last_part) + for key, value in variables_map.items(): + if key in user_input_variables_keys: + user_input_variables.append(value) + + # get credentials + datasource_provider_service: DatasourceProviderService = DatasourceProviderService() + credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials( + tenant_id=tenant_id, + provider=datasource_node_data.get("provider_name"), + plugin_id=datasource_node_data.get("plugin_id"), + ) + credential_info_list: list[Any] = [] + for credential in credentials: + credential_info_list.append( + { + "id": credential.get("id"), + "name": credential.get("name"), + "type": credential.get("type"), + "is_default": credential.get("is_default"), + } + ) + + datasource_plugins.append( + { + "node_id": datasource_node.get("id"), + "plugin_id": datasource_node_data.get("plugin_id"), + "provider_name": datasource_node_data.get("provider_name"), + "datasource_type": datasource_node_data.get("provider_type"), + "title": datasource_node_data.get("title"), + "user_input_variables": user_input_variables, + "credentials": credential_info_list, + } + ) + + return datasource_plugins + + def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline: + """ + Get pipeline + """ + dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + if not dataset: + raise ValueError("Dataset not found") + pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first() + if not pipeline: + raise ValueError("Pipeline not found") + return pipeline diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index bb04728f3a..845e14ca70 100644 
--- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -99,12 +99,6 @@ class ToolTransformService: provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url( tenant_id=tenant_id, filename=provider.declaration.identity.icon ) - else: - provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url( - provider_type=provider.type.value, - provider_name=provider.name, - icon=provider.declaration.identity.icon, - ) @classmethod def builtin_provider_to_user_provider( From eefcd3ecc47a0ad1c4f6e1433883811c85389415 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 17 Sep 2025 22:31:19 +0800 Subject: [PATCH 02/13] chore(api): apply autofix manully --- .../datasets/rag_pipeline/rag_pipeline.py | 2 +- api/controllers/console/datasets/wraps.py | 3 +- api/core/app/app_config/entities.py | 14 +++--- .../app/apps/pipeline/pipeline_generator.py | 20 ++++----- api/core/app/apps/pipeline/pipeline_runner.py | 12 ++--- api/core/app/entities/app_invoke_entities.py | 8 ++-- api/core/app/entities/task_entities.py | 38 ++++++++-------- .../datasource/__base/datasource_runtime.py | 4 +- .../datasource/datasource_file_manager.py | 28 ++++-------- api/core/datasource/entities/api_entities.py | 12 ++--- .../datasource/entities/common_entities.py | 7 ++- .../entities/datasource_entities.py | 44 +++++++++---------- .../datasource/utils/message_transformer.py | 5 +-- api/core/datasource/utils/parser.py | 5 +-- api/core/entities/knowledge_entities.py | 7 ++- api/core/plugin/entities/plugin.py | 20 ++++----- api/core/rag/entities/event.py | 12 ++--- .../rag/extractor/entity/extract_setting.py | 3 +- api/core/rag/extractor/notion_extractor.py | 10 ++--- ...hemy_workflow_node_execution_repository.py | 4 +- api/core/schemas/registry.py | 2 +- api/core/schemas/resolver.py | 12 ++--- api/core/schemas/schema_manager.py | 6 +-- .../entities/workflow_node_execution.py | 8 ++-- .../nodes/datasource/datasource_node.py | 8 ++-- .../workflow/nodes/datasource/entities.py | 6 +-- .../nodes/knowledge_index/entities.py | 12 ++--- .../knowledge_index/knowledge_index_node.py | 6 +-- api/models/dataset.py | 16 +++---- api/models/workflow.py | 2 +- api/services/dataset_service.py | 6 +-- .../rag_pipeline_entities.py | 32 +++++++------- .../entity/pipeline_service_api_entities.py | 4 +- .../rag_pipeline/pipeline_generate_service.py | 2 +- .../built_in/built_in_retrieval.py | 5 +-- .../customized/customized_retrieval.py | 5 +-- .../database/database_retrieval.py | 7 ++- .../pipeline_template_base.py | 3 +- .../remote/remote_retrieval.py | 3 +- api/services/rag_pipeline/rag_pipeline.py | 40 ++++++++--------- .../rag_pipeline/rag_pipeline_dsl_service.py | 22 +++++----- .../rag_pipeline_transform_service.py | 11 +++-- api/tasks/batch_clean_document_task.py | 3 +- api/tasks/deal_dataset_index_update_task.py | 12 ++--- .../priority_rag_pipeline_run_task.py | 8 ++-- .../rag_pipeline/rag_pipeline_run_task.py | 8 ++-- api/tasks/retry_document_indexing_task.py | 2 +- 47 files changed, 241 insertions(+), 268 deletions(-) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index e6360706ee..f04b0e04c3 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -104,7 +104,7 @@ class CustomizedPipelineTemplateApi(Resource): def post(self, template_id: str): with Session(db.engine) as 
session: template = ( - session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() ) if not template: raise ValueError("Customized pipeline template not found.") diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 26783d8cf8..33751ab231 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -1,6 +1,5 @@ from collections.abc import Callable from functools import wraps -from typing import Optional from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db @@ -10,7 +9,7 @@ from models.dataset import Pipeline def get_rag_pipeline( - view: Optional[Callable] = None, + view: Callable | None = None, ): def decorator(view_func): @wraps(view_func) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 3dbc8706d3..e836a46f8f 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from enum import StrEnum, auto -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -114,9 +114,9 @@ class VariableEntity(BaseModel): hide: bool = False max_length: int | None = None options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list) - allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list) - allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list) + allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) + allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) @field_validator("description", mode="before") @classmethod @@ -134,8 +134,8 @@ class RagPipelineVariableEntity(VariableEntity): Rag Pipeline Variable Entity. """ - tooltips: Optional[str] = None - placeholder: Optional[str] = None + tooltips: str | None = None + placeholder: str | None = None belong_to_node_id: str @@ -298,7 +298,7 @@ class AppConfig(BaseModel): tenant_id: str app_id: str app_mode: AppMode - additional_features: Optional[AppAdditionalFeatures] = None + additional_features: AppAdditionalFeatures | None = None variables: list[VariableEntity] = [] sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index b1e98ed3ea..d441f273d8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -7,7 +7,7 @@ import threading import time import uuid from collections.abc import Generator, Mapping -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from flask import Flask, current_app from pydantic import ValidationError @@ -69,7 +69,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, is_retry: bool = False, ) -> Generator[Mapping | str, None, None]: ... 
@@ -84,7 +84,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, is_retry: bool = False, ) -> Mapping[str, Any]: ... @@ -99,7 +99,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, is_retry: bool = False, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... @@ -113,7 +113,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, is_retry: bool = False, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # Add null check for dataset @@ -314,7 +314,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -331,7 +331,7 @@ class PipelineGenerator(BaseAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): # init queue manager - workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first() + workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first() if not workflow: raise ValueError(f"Workflow not found: {workflow_id}") queue_manager = PipelineQueueManager( @@ -568,7 +568,7 @@ class PipelineGenerator(BaseAppGenerator): queue_manager: AppQueueManager, context: contextvars.Context, variable_loader: VariableLoader, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> None: """ Generate worker in a new thread. @@ -801,11 +801,11 @@ class PipelineGenerator(BaseAppGenerator): self, datasource_runtime: OnlineDriveDatasourcePlugin, prefix: str, - bucket: Optional[str], + bucket: str | None, user_id: str, all_files: list, datasource_info: Mapping[str, Any], - next_page_parameters: Optional[dict] = None, + next_page_parameters: dict | None = None, ): """ Get files in a folder. 
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index f2f01d1ee7..3b9bd224d9 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -1,6 +1,6 @@ import logging import time -from typing import Optional, cast +from typing import cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig @@ -40,7 +40,7 @@ class PipelineRunner(WorkflowBasedAppRunner): variable_loader: VariableLoader, workflow: Workflow, system_user_id: str, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> None: """ :param application_generate_entity: application generate entity @@ -69,13 +69,13 @@ class PipelineRunner(WorkflowBasedAppRunner): user_id = None if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: user_id = self.application_generate_entity.user_id - pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first() if not pipeline: raise ValueError("Pipeline not found") @@ -188,7 +188,7 @@ class PipelineRunner(WorkflowBasedAppRunner): ) self._handle_event(workflow_entry, event) - def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: + def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: """ Get workflow """ @@ -205,7 +205,7 @@ class PipelineRunner(WorkflowBasedAppRunner): return workflow def _init_rag_pipeline_graph( - self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None + self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None ) -> Graph: """ Init pipeline graph diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 14dd78489a..a5ed0f8fa3 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -242,7 +242,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): @@ -256,9 +256,9 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): datasource_info: Mapping[str, Any] dataset_id: str batch: str - document_id: Optional[str] = None - original_document_id: Optional[str] = None - start_node_id: Optional[str] = None + document_id: str | None = None + original_document_id: str | None = None + start_node_id: str | None = None # Import TraceQueueManager at runtime to resolve forward references diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index fe8c916d3a..31dc1eea89 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ 
-252,8 +252,8 @@ class NodeStartStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None inputs_truncated: bool = False created_at: int extras: dict[str, object] = Field(default_factory=dict) @@ -310,12 +310,12 @@ class NodeFinishStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None inputs_truncated: bool = False - process_data: Optional[Mapping[str, Any]] = None + process_data: Mapping[str, Any] | None = None process_data_truncated: bool = False - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True status: str error: str | None = None @@ -382,12 +382,12 @@ class NodeRetryStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None inputs_truncated: bool = False - process_data: Optional[Mapping[str, Any]] = None + process_data: Mapping[str, Any] | None = None process_data_truncated: bool = False - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False status: str error: str | None = None @@ -503,11 +503,11 @@ class IterationNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus error: str | None = None @@ -541,8 +541,8 @@ class LoopNodeStartStreamResponse(StreamResponse): metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_STARTED workflow_run_id: str @@ -590,11 +590,11 @@ class LoopNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus error: str | None = None diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index 7d24bd7c6d..b7f280208a 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -17,9 +17,9 @@ class DatasourceRuntime(BaseModel): """ tenant_id: str - datasource_id: Optional[str] = None + datasource_id: str | None = None invoke_from: Optional["InvokeFrom"] = None - datasource_invoke_from: Optional[DatasourceInvokeFrom] = None + datasource_invoke_from: DatasourceInvokeFrom | None = None credentials: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = 
Field(default_factory=dict) diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index beb5ce7b04..f4e3c656bc 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -6,7 +6,7 @@ import os import time from datetime import datetime from mimetypes import guess_extension, guess_type -from typing import Optional, Union +from typing import Union from uuid import uuid4 import httpx @@ -62,10 +62,10 @@ class DatasourceFileManager: *, user_id: str, tenant_id: str, - conversation_id: Optional[str], + conversation_id: str | None, file_binary: bytes, mimetype: str, - filename: Optional[str] = None, + filename: str | None = None, ) -> UploadFile: extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex @@ -106,7 +106,7 @@ class DatasourceFileManager: user_id: str, tenant_id: str, file_url: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> ToolFile: # try to download image try: @@ -153,10 +153,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ upload_file: UploadFile | None = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == id, - ) + db.session.query(UploadFile).where(UploadFile.id == id) .first() ) @@ -177,10 +174,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ message_file: MessageFile | None = ( - db.session.query(MessageFile) - .filter( - MessageFile.id == id, - ) + db.session.query(MessageFile).where(MessageFile.id == id) .first() ) @@ -197,10 +191,7 @@ class DatasourceFileManager: tool_file_id = None tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, - ) + db.session.query(ToolFile).where(ToolFile.id == tool_file_id) .first() ) @@ -221,10 +212,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ upload_file: UploadFile | None = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == upload_file_id, - ) + db.session.query(UploadFile).where(UploadFile.id == upload_file_id) .first() ) diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 81771719ea..af8ce4ed9b 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -12,9 +12,9 @@ class DatasourceApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[DatasourceParameter]] = None + parameters: list[DatasourceParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None + output_schema: dict | None = None ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] @@ -28,12 +28,12 @@ class DatasourceProviderApiEntity(BaseModel): icon: str | dict label: I18nObject # label type: str - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None + masked_credentials: dict | None = None + original_credentials: dict | None = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource") + plugin_id: str | None = Field(default="", description="The plugin id of the datasource") + plugin_unique_identifier: str | None = Field(default="", 
description="The unique identifier of the datasource") datasources: list[DatasourceApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py index 924e6fc0cf..98680a5779 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, Field @@ -9,9 +8,9 @@ class I18nObject(BaseModel): """ en_US: str - zh_Hans: Optional[str] = Field(default=None) - pt_BR: Optional[str] = Field(default=None) - ja_JP: Optional[str] = Field(default=None) + zh_Hans: str | None = Field(default=None) + pt_BR: str | None = Field(default=None) + ja_JP: str | None = Field(default=None) def __init__(self, **data): super().__init__(**data) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 0c2011f841..ac4f51ac75 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,6 +1,6 @@ import enum from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, ValidationInfo, field_validator from yarl import URL @@ -80,7 +80,7 @@ class DatasourceParameter(PluginParameter): name: str, typ: DatasourceParameterType, required: bool, - options: Optional[list[str]] = None, + options: list[str] | None = None, ) -> "DatasourceParameter": """ get a simple datasource parameter @@ -120,14 +120,14 @@ class DatasourceIdentity(BaseModel): name: str = Field(..., description="The name of the datasource") label: I18nObject = Field(..., description="The label of the datasource") provider: str = Field(..., description="The provider of the datasource") - icon: Optional[str] = None + icon: str | None = None class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The label of the datasource") - output_schema: Optional[dict] = None + output_schema: dict | None = None @field_validator("parameters", mode="before") @classmethod @@ -141,7 +141,7 @@ class DatasourceProviderIdentity(BaseModel): description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( + tags: list[ToolLabelEnum] | None = Field( default=[], description="The tags of the tool", ) @@ -169,7 +169,7 @@ class DatasourceProviderEntity(BaseModel): identity: DatasourceProviderIdentity credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = None + oauth_schema: OAuthSchema | None = None provider_type: DatasourceProviderType @@ -183,8 +183,8 @@ class DatasourceInvokeMeta(BaseModel): """ time_cost: float = Field(..., description="The time cost of the tool invoke") - error: Optional[str] = None - tool_config: Optional[dict] = None + error: str | None = None + tool_config: dict | None = None @classmethod def empty(cls) -> "DatasourceInvokeMeta": @@ -233,10 +233,10 @@ class OnlineDocumentPage(BaseModel): page_id: str = Field(..., description="The page id") page_name: str = Field(..., description="The page title") - page_icon: Optional[dict] = Field(None, description="The 
page icon") + page_icon: dict | None = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") - parent_id: Optional[str] = Field(None, description="The parent page id") + parent_id: str | None = Field(None, description="The parent page id") class OnlineDocumentInfo(BaseModel): @@ -244,9 +244,9 @@ class OnlineDocumentInfo(BaseModel): Online document info """ - workspace_id: Optional[str] = Field(None, description="The workspace id") - workspace_name: Optional[str] = Field(None, description="The workspace name") - workspace_icon: Optional[str] = Field(None, description="The workspace icon") + workspace_id: str | None = Field(None, description="The workspace id") + workspace_name: str | None = Field(None, description="The workspace name") + workspace_icon: str | None = Field(None, description="The workspace icon") total: int = Field(..., description="The total number of documents") pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") @@ -307,10 +307,10 @@ class WebSiteInfo(BaseModel): Website info """ - status: Optional[str] = Field(..., description="crawl job status") - web_info_list: Optional[list[WebSiteInfoDetail]] = [] - total: Optional[int] = Field(default=0, description="The total number of websites") - completed: Optional[int] = Field(default=0, description="The number of completed websites") + status: str | None = Field(..., description="crawl job status") + web_info_list: list[WebSiteInfoDetail] | None = [] + total: int | None = Field(default=0, description="The total number of websites") + completed: int | None = Field(default=0, description="The number of completed websites") class WebsiteCrawlMessage(BaseModel): @@ -346,10 +346,10 @@ class OnlineDriveFileBucket(BaseModel): Online drive file bucket """ - bucket: Optional[str] = Field(None, description="The file bucket") + bucket: str | None = Field(None, description="The file bucket") files: list[OnlineDriveFile] = Field(..., description="The file list") is_truncated: bool = Field(False, description="Whether the result is truncated") - next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesRequest(BaseModel): @@ -357,10 +357,10 @@ class OnlineDriveBrowseFilesRequest(BaseModel): Get online drive file list request """ - bucket: Optional[str] = Field(None, description="The file bucket") + bucket: str | None = Field(None, description="The file bucket") prefix: str = Field(..., description="The parent folder ID") max_keys: int = Field(20, description="Page size for pagination") - next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesResponse(BaseModel): @@ -377,4 +377,4 @@ class OnlineDriveDownloadFileRequest(BaseModel): """ id: str = Field(..., description="The id of the file") - bucket: Optional[str] = Field(None, description="The name of the bucket") + bucket: str | None = Field(None, description="The name of the bucket") diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index bb6ac6c1fc..5aa25b573f 100644 --- a/api/core/datasource/utils/message_transformer.py +++ 
b/api/core/datasource/utils/message_transformer.py @@ -1,7 +1,6 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Optional from core.datasource.entities.datasource_entities import DatasourceMessage from core.file import File, FileTransferMethod, FileType @@ -17,7 +16,7 @@ class DatasourceFileMessageTransformer: messages: Generator[DatasourceMessage, None, None], user_id: str, tenant_id: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> Generator[DatasourceMessage, None, None]: """ Transform datasource message and handle file download @@ -121,5 +120,5 @@ class DatasourceFileMessageTransformer: yield message @classmethod - def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str: + def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str: return f"/files/datasources/{datasource_file_id}{extension or '.bin'}" diff --git a/api/core/datasource/utils/parser.py b/api/core/datasource/utils/parser.py index 57ee15d7f2..db1766a059 100644 --- a/api/core/datasource/utils/parser.py +++ b/api/core/datasource/utils/parser.py @@ -3,7 +3,6 @@ import uuid from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError -from typing import Optional from flask import request from requests import get @@ -169,9 +168,9 @@ class ApiBasedToolSchemaParser: return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} - typ: Optional[str] = None + typ: str | None = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 33e1f64579..f6da4c7094 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel @@ -16,7 +15,7 @@ class QAPreviewDetail(BaseModel): class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] - qa_preview: Optional[list[QAPreviewDetail]] = None + qa_preview: list[QAPreviewDetail] | None = None class PipelineDataset(BaseModel): @@ -30,10 +29,10 @@ class PipelineDocument(BaseModel): id: str position: int data_source_type: str - data_source_info: Optional[dict] = None + data_source_info: dict | None = None name: str indexing_status: str - error: Optional[str] = None + error: str | None = None enabled: bool diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 3063cd39ae..57012bf495 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -1,7 +1,7 @@ import datetime from collections.abc import Mapping from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -67,10 +67,10 @@ class PluginCategory(StrEnum): class PluginDeclaration(BaseModel): class Plugins(BaseModel): - tools: Optional[list[str]] = Field(default_factory=list[str]) - models: Optional[list[str]] = Field(default_factory=list[str]) - endpoints: Optional[list[str]] = Field(default_factory=list[str]) - datasources: Optional[list[str]] = 
Field(default_factory=list[str]) + tools: list[str] | None = Field(default_factory=list[str]) + models: list[str] | None = Field(default_factory=list[str]) + endpoints: list[str] | None = Field(default_factory=list[str]) + datasources: list[str] | None = Field(default_factory=list[str]) class Meta(BaseModel): minimum_dify_version: str | None = Field(default=None) @@ -101,11 +101,11 @@ class PluginDeclaration(BaseModel): tags: list[str] = Field(default_factory=list) repo: str | None = Field(default=None) verified: bool = Field(default=False) - tool: Optional[ToolProviderEntity] = None - model: Optional[ProviderEntity] = None - endpoint: Optional[EndpointProviderDeclaration] = None - agent_strategy: Optional[AgentStrategyProviderEntity] = None - datasource: Optional[DatasourceProviderEntity] = None + tool: ToolProviderEntity | None = None + model: ProviderEntity | None = None + endpoint: EndpointProviderDeclaration | None = None + agent_strategy: AgentStrategyProviderEntity | None = None + datasource: DatasourceProviderEntity | None = None meta: Meta @field_validator("version") diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index a36e32fc9c..24db5d77be 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -27,12 +27,12 @@ class DatasourceErrorEvent(BaseDatasourceEvent): class DatasourceCompletedEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.COMPLETED.value data: Mapping[str, Any] | list = Field(..., description="result") - total: Optional[int] = Field(default=0, description="total") - completed: Optional[int] = Field(default=0, description="completed") - time_consuming: Optional[float] = Field(default=0.0, description="time consuming") + total: int | None = Field(default=0, description="total") + completed: int | None = Field(default=0, description="completed") + time_consuming: float | None = Field(default=0.0, description="time consuming") class DatasourceProcessingEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.PROCESSING.value - total: Optional[int] = Field(..., description="total") - completed: Optional[int] = Field(..., description="completed") + total: int | None = Field(..., description="total") + completed: int | None = Field(..., description="completed") diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index c0e79b02c4..b5eea0bf30 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -11,7 +10,7 @@ class NotionInfo(BaseModel): Notion import info. 
""" - credential_id: Optional[str] = None + credential_id: str | None = None notion_workspace_id: str notion_obj_id: str notion_page_type: str diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index c1563840f0..bddf41af43 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,7 +1,7 @@ import json import logging import operator -from typing import Any, Optional, cast +from typing import Any, cast import requests @@ -35,9 +35,9 @@ class NotionExtractor(BaseExtractor): notion_obj_id: str, notion_page_type: str, tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - credential_id: Optional[str] = None, + document_model: DocumentModel | None = None, + notion_access_token: str | None = None, + credential_id: str | None = None, ): self._notion_access_token = None self._document_model = document_model @@ -369,7 +369,7 @@ class NotionExtractor(BaseExtractor): return cast(str, data["last_edited_time"]) @classmethod - def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str: + def _get_access_token(cls, tenant_id: str, credential_id: str | None) -> str: # get credential from tenant_id and credential_id if not credential_id: raise Exception(f"No credential id found for tenant {tenant_id}") diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 219aec5a03..5226a1071f 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ import json import logging from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar, Union import psycopg2.errors from sqlalchemy import UnaryExpression, asc, desc, select @@ -530,7 +530,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index 339784267c..867e4803bc 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -85,7 +85,7 @@ class SchemaRegistry: except (OSError, json.JSONDecodeError) as e: print(f"Warning: failed to load schema {version}/{schema_name}: {e}") - def get_schema(self, uri: str) -> Optional[Any]: + def get_schema(self, uri: str) -> Any | None: """Retrieves a schema by URI with version support""" version, schema_name = self._parse_uri(uri) if not version or not schema_name: diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index 1c5dabd79b..1b57f5bb94 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -3,7 +3,7 @@ import re import threading from collections import deque from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Union from core.schemas.registry import SchemaRegistry @@ -53,8 +53,8 @@ class QueueItem: """Represents an item in the BFS queue""" current: Any - parent: Optional[Any] - key: Optional[Union[str, 
int]] + parent: Any | None + key: Union[str, int] | None depth: int ref_path: set[str] @@ -65,7 +65,7 @@ class SchemaResolver: _cache: dict[str, SchemaDict] = {} _cache_lock = threading.Lock() - def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10): + def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10): """ Initialize the schema resolver @@ -202,7 +202,7 @@ class SchemaResolver: ) ) - def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]: + def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None: """Get resolved schema from cache or registry""" # Check cache first with self._cache_lock: @@ -223,7 +223,7 @@ class SchemaResolver: def resolve_dify_schema_refs( - schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30 + schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30 ) -> SchemaType: """ Resolve $ref references in Dify schema to actual schema content diff --git a/api/core/schemas/schema_manager.py b/api/core/schemas/schema_manager.py index 3c9314db66..833ab609c7 100644 --- a/api/core/schemas/schema_manager.py +++ b/api/core/schemas/schema_manager.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.schemas.registry import SchemaRegistry @@ -7,7 +7,7 @@ from core.schemas.registry import SchemaRegistry class SchemaManager: """Schema manager provides high-level schema operations""" - def __init__(self, registry: Optional[SchemaRegistry] = None): + def __init__(self, registry: SchemaRegistry | None = None): self.registry = registry or SchemaRegistry.default_registry() def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]: @@ -22,7 +22,7 @@ class SchemaManager: """ return self.registry.get_all_schemas_for_version(version) - def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]: + def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Mapping[str, Any] | None: """ Get a specific schema by name diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index ef3022352a..4abc9c068d 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc. 
from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, PrivateAttr @@ -53,9 +53,9 @@ class WorkflowNodeExecution(BaseModel): # Execution data # The `inputs` and `outputs` fields hold the full content - inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node - process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data - outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + inputs: Mapping[str, Any] | None = None # Input variables used by this node + process_data: Mapping[str, Any] | None = None # Intermediate processing data + outputs: Mapping[str, Any] | None = None # Output variables produced by this node # Execution state status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index eb58ba14c1..6cf0c91c30 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -50,7 +50,7 @@ class DatasourceNode(Node): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = DatasourceNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -59,7 +59,7 @@ class DatasourceNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -179,7 +179,7 @@ class DatasourceNode(Node): related_id = datasource_info.get("related_id") if not related_id: raise DatasourceNodeError("File is not exist") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first() if not upload_file: raise ValueError("Invalid upload file Info") diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index b182928baa..4802d3ed98 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo @@ -10,7 +10,7 @@ class DatasourceEntity(BaseModel): plugin_id: str provider_name: str # redundancy provider_type: str - datasource_name: Optional[str] = "local_file" + datasource_name: str | None = "local_file" datasource_configurations: dict[str, Any] | None = None plugin_unique_identifier: str | None = None # redundancy @@ -19,7 +19,7 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity): class DatasourceInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Optional[Literal["mixed", "variable", "constant"]] = None + type: Literal["mixed", "variable", "constant"] | None = None @field_validator("type", 
mode="before") @classmethod diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 85c0f695c6..2a2e983a0c 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal, Union from pydantic import BaseModel @@ -65,12 +65,12 @@ class RetrievalSetting(BaseModel): search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"] top_k: int - score_threshold: Optional[float] = 0.5 + score_threshold: float | None = 0.5 score_threshold_enabled: bool = False reranking_mode: str = "reranking_model" reranking_enable: bool = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class IndexMethod(BaseModel): @@ -107,10 +107,10 @@ class OnlineDocumentInfo(BaseModel): """ provider: str - workspace_id: Optional[str] = None + workspace_id: str | None = None page_id: str page_type: str - icon: Optional[OnlineDocumentIcon] = None + icon: OnlineDocumentIcon | None = None class WebsiteInfo(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d7641bc123..d5ced1a246 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,7 +2,7 @@ import datetime import logging import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import func, select @@ -43,7 +43,7 @@ class KnowledgeIndexNode(Node): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = KnowledgeIndexNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -52,7 +52,7 @@ class KnowledgeIndexNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/models/dataset.py b/api/models/dataset.py index 2c03a0c30c..d620d56006 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -10,7 +10,7 @@ import re import time from datetime import datetime from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select @@ -76,7 +76,7 @@ class Dataset(Base): @property def total_documents(self): - return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() @property def total_available_documents(self): @@ -173,10 +173,10 @@ class Dataset(Base): ) @property - def doc_form(self) -> Optional[str]: + def doc_form(self) -> str | None: if self.chunk_structure: return self.chunk_structure - document = db.session.query(Document).filter(Document.dataset_id == self.id).first() + document = db.session.query(Document).where(Document.dataset_id == self.id).first() if 
document: return document.doc_form return None @@ -234,7 +234,7 @@ class Dataset(Base): @property def is_published(self): if self.pipeline_id: - pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first() if pipeline: return pipeline.is_published return False @@ -1244,7 +1244,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] @property def created_user_name(self): - account = db.session.query(Account).filter(Account.id == self.created_by).first() + account = db.session.query(Account).where(Account.id == self.created_by).first() if account: return account.name return "" @@ -1274,7 +1274,7 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] @property def created_user_name(self): - account = db.session.query(Account).filter(Account.id == self.created_by).first() + account = db.session.query(Account).where(Account.id == self.created_by).first() if account: return account.name return "" @@ -1297,7 +1297,7 @@ class Pipeline(Base): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def retrieve_dataset(self, session: Session): - return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() + return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() class DocumentPipelineExecutionLog(Base): diff --git a/api/models/workflow.py b/api/models/workflow.py index bb7ea2c074..5f604a51a8 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1546,7 +1546,7 @@ class WorkflowDraftVariableFile(Base): comment="Size of the original variable content in bytes", ) - length: Mapped[Optional[int]] = mapped_column( + length: Mapped[int | None] = mapped_column( sa.Integer, nullable=True, comment=( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 84e36cb80b..798233fd95 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal import sqlalchemy as sa from sqlalchemy import exists, func, select @@ -315,8 +315,8 @@ class DatasetService: return dataset @staticmethod - def get_dataset(dataset_id) -> Optional[Dataset]: - dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Dataset | None: + dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset @staticmethod diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index e215a89c15..ac96b5c8ad 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,13 +1,13 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, field_validator class IconInfo(BaseModel): icon: str - icon_background: Optional[str] = None - icon_type: Optional[str] = None - icon_url: Optional[str] = None + icon_background: str | None = None + icon_type: str | None = None + icon_url: str | None = None class PipelineTemplateInfoEntity(BaseModel): @@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str 
- partial_member_list: Optional[list[str]] = None - yaml_content: Optional[str] = None + partial_member_list: list[str] | None = None + yaml_content: str | None = None class RerankingModelConfig(BaseModel): @@ -30,8 +30,8 @@ class RerankingModelConfig(BaseModel): Reranking Model Config. """ - reranking_provider_name: Optional[str] = "" - reranking_model_name: Optional[str] = "" + reranking_provider_name: str | None = "" + reranking_model_name: str | None = "" class VectorSetting(BaseModel): @@ -57,8 +57,8 @@ class WeightedScoreConfig(BaseModel): Weighted score Config. """ - vector_setting: Optional[VectorSetting] - keyword_setting: Optional[KeywordSetting] + vector_setting: VectorSetting | None + keyword_setting: KeywordSetting | None class EmbeddingSetting(BaseModel): @@ -85,12 +85,12 @@ class RetrievalSetting(BaseModel): search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"] top_k: int - score_threshold: Optional[float] = 0.5 + score_threshold: float | None = 0.5 score_threshold_enabled: bool = False - reranking_mode: Optional[str] = "reranking_model" - reranking_enable: Optional[bool] = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_mode: str | None = "reranking_model" + reranking_enable: bool | None = True + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class IndexMethod(BaseModel): @@ -112,7 +112,7 @@ class KnowledgeConfiguration(BaseModel): indexing_technique: Literal["high_quality", "economy"] embedding_model_provider: str = "" embedding_model: str = "" - keyword_number: Optional[int] = 10 + keyword_number: int | None = 10 retrieval_model: RetrievalSetting @field_validator("embedding_model_provider", mode="before") diff --git a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py index 35005fad71..41f46a55a7 100644 --- a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py +++ b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -9,7 +9,7 @@ class DatasourceNodeRunApiEntity(BaseModel): node_id: str inputs: Mapping[str, Any] datasource_type: str - credential_id: Optional[str] = None + credential_id: str | None = None is_published: bool diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 563174c528..e6cee64df6 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -108,7 +108,7 @@ class PipelineGenerateService: Update document status to waiting :param document_id: document id """ - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() if document: document.indexing_status = "waiting" db.session.add(document) diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index b0fa54115c..24baeb73b5 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -1,7 +1,6 @@ import json from os import path 
from pathlib import Path -from typing import Optional from flask import current_app @@ -14,7 +13,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json """ - builtin_data: Optional[dict] = None + builtin_data: dict | None = None def get_type(self) -> str: return PipelineTemplateType.BUILTIN @@ -54,7 +53,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod - def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from builtin. :param template_id: Template ID diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 3380d23ec4..82a0a08ec6 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,3 @@ -from typing import Optional import yaml from flask_login import current_user @@ -56,14 +55,14 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from db. :param template_id: Template ID :return: """ pipeline_template = ( - db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() ) if not pipeline_template: return None diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 709702fe11..a544767465 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,3 @@ -from typing import Optional import yaml @@ -33,7 +32,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( - db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() + db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all() ) recommended_pipelines_results = [] @@ -53,7 +52,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from db. 
:param pipeline_id: Pipeline ID @@ -61,7 +60,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ # is in public recommended list pipeline_template = ( - db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == template_id).first() + db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first() ) if not pipeline_template: diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index fa6a38a357..21c30a4986 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional class PipelineTemplateRetrievalBase(ABC): @@ -10,7 +9,7 @@ class PipelineTemplateRetrievalBase(ABC): raise NotImplementedError @abstractmethod - def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]: + def get_pipeline_template_detail(self, template_id: str) -> dict | None: raise NotImplementedError @abstractmethod diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index e541a7bc0b..8f96842337 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests @@ -36,7 +35,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from dify official. :param template_id: Pipeline ID diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index e27d78b980..88e1dab23e 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user @@ -112,7 +112,7 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: """ Get pipeline template detail. 
:param template_id: template id @@ -121,12 +121,12 @@ class RagPipelineService: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) return built_in_result else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) return customized_result @classmethod @@ -185,7 +185,7 @@ class RagPipelineService: db.session.delete(customized_template) db.session.commit() - def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None: """ Get draft workflow """ @@ -203,7 +203,7 @@ class RagPipelineService: # return draft workflow return workflow - def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None: """ Get published workflow """ @@ -267,7 +267,7 @@ class RagPipelineService: *, pipeline: Pipeline, graph: dict, - unique_hash: Optional[str], + unique_hash: str | None, account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], @@ -387,9 +387,7 @@ class RagPipelineService: return default_block_configs - def get_default_block_config( - self, node_type: str, filters: Optional[dict] = None - ) -> Optional[Mapping[str, object]]: + def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: """ Get default config of node. 
:param node_type: node type @@ -495,7 +493,7 @@ class RagPipelineService: account: Account, datasource_type: str, is_published: bool, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Generator[Mapping[str, Any], None, None]: """ Run published workflow datasource @@ -661,7 +659,7 @@ class RagPipelineService: account: Account, datasource_type: str, is_published: bool, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Mapping[str, Any]: """ Run published workflow datasource @@ -876,7 +874,7 @@ class RagPipelineService: if invoke_from.value == InvokeFrom.PUBLISHED.value: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: - document = db.session.query(Document).filter(Document.id == document_id.value).first() + document = db.session.query(Document).where(Document.id == document_id.value).first() if document: document.indexing_status = "error" document.error = error @@ -887,7 +885,7 @@ class RagPipelineService: def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict - ) -> Optional[Workflow]: + ) -> Workflow | None: """ Update workflow attributes @@ -1057,7 +1055,7 @@ class RagPipelineService: return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]: + def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None: """ Get workflow run detail @@ -1113,12 +1111,12 @@ class RagPipelineService: """ Publish customized pipeline template """ - pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found") if not pipeline.workflow_id: raise ValueError("Pipeline workflow not found") - workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError("Workflow not found") with Session(db.engine) as session: @@ -1142,7 +1140,7 @@ class RagPipelineService: max_position = ( db.session.query(func.max(PipelineCustomizedTemplate.position)) - .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) + .where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) .scalar() ) @@ -1278,7 +1276,7 @@ class RagPipelineService: # Query active recommended plugins pipeline_recommended_plugins = ( db.session.query(PipelineRecommendedPlugin) - .filter(PipelineRecommendedPlugin.active == True) + .where(PipelineRecommendedPlugin.active == True) .order_by(PipelineRecommendedPlugin.position.asc()) .all() ) @@ -1329,12 +1327,12 @@ class RagPipelineService: """ document_pipeline_excution_log = ( db.session.query(DocumentPipelineExecutionLog) - .filter(DocumentPipelineExecutionLog.document_id == document.id) + .where(DocumentPipelineExecutionLog.document_id == document.id) .first() ) if not document_pipeline_excution_log: raise ValueError("Document pipeline execution log not found") - pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found") # convert to app config diff --git 
a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index fe92f6b084..e21d2d56bc 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -6,7 +6,7 @@ import uuid from collections.abc import Mapping from datetime import UTC, datetime from enum import StrEnum -from typing import Optional, cast +from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -66,11 +66,11 @@ class ImportStatus(StrEnum): class RagPipelineImportInfo(BaseModel): id: str status: ImportStatus - pipeline_id: Optional[str] = None + pipeline_id: str | None = None current_dsl_version: str = CURRENT_DSL_VERSION imported_dsl_version: str = "" error: str = "" - dataset_id: Optional[str] = None + dataset_id: str | None = None class CheckDependenciesResult(BaseModel): @@ -121,12 +121,12 @@ class RagPipelineDslService: *, account: Account, import_mode: str, - yaml_content: Optional[str] = None, - yaml_url: Optional[str] = None, - pipeline_id: Optional[str] = None, - dataset: Optional[Dataset] = None, - dataset_name: Optional[str] = None, - icon_info: Optional[IconInfo] = None, + yaml_content: str | None = None, + yaml_url: str | None = None, + pipeline_id: str | None = None, + dataset: Dataset | None = None, + dataset_name: str | None = None, + icon_info: IconInfo | None = None, ) -> RagPipelineImportInfo: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) @@ -530,10 +530,10 @@ class RagPipelineDslService: def _create_or_update_pipeline( self, *, - pipeline: Optional[Pipeline], + pipeline: Pipeline | None, data: dict, account: Account, - dependencies: Optional[list[PluginDependency]] = None, + dependencies: list[PluginDependency] | None = None, ) -> Pipeline: """Create a new app or update an existing one.""" if not account.current_tenant_id: diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 78440b4889..c2dbb484cf 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -1,7 +1,6 @@ import json from datetime import UTC, datetime from pathlib import Path -from typing import Optional from uuid import uuid4 import yaml @@ -21,7 +20,7 @@ from services.plugin.plugin_service import PluginService class RagPipelineTransformService: def transform_dataset(self, dataset_id: str): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline": @@ -90,7 +89,7 @@ class RagPipelineTransformService: "status": "success", } - def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]): + def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} if doc_form == "text_model": match datasource_type: @@ -152,7 +151,7 @@ class RagPipelineTransformService: return node def _deal_knowledge_index( - self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict + self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict ): knowledge_configuration_dict = node.get("data", {}) knowledge_configuration = 
KnowledgeConfiguration(**knowledge_configuration_dict) @@ -289,7 +288,7 @@ class RagPipelineTransformService: jina_node_id = "1752491761974" firecrawl_node_id = "1752565402678" - documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all() + documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all() for document in documents: data_source_info_dict = document.data_source_info_dict @@ -299,7 +298,7 @@ class RagPipelineTransformService: document.data_source_type = "local_file" file_id = data_source_info_dict.get("upload_file_id") if file_id: - file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if file: data_source_info = json.dumps( { diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 9b3857b4a5..447443703a 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional import click from celery import shared_task @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]): +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]): """ Clean document when document deleted. :param document_ids: document ids diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index dc266aef65..df4a76d94f 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -44,7 +44,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -76,12 +76,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() @@ -100,7 +100,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -148,12 +148,12 @@ def 
deal_dataset_index_update_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 4780e48558..028f635188 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -104,20 +104,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], with Session(db.engine, expire_on_commit=False) as session: # Load required entities - account = session.query(Account).filter(Account.id == user_id).first() + account = session.query(Account).where(Account.id == user_id).first() if not account: raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + tenant = session.query(Tenant).where(Tenant.id == tenant_id).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found") account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first() if not pipeline: raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError(f"Workflow {pipeline.workflow_id} not found") diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 72916972df..ee904c4649 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -125,20 +125,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], with Session(db.engine) as session: # Load required entities - account = session.query(Account).filter(Account.id == user_id).first() + account = session.query(Account).where(Account.id == user_id).first() if not account: raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + tenant = session.query(Tenant).where(Tenant.id == tenant_id).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found") account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first() if not pipeline: raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError(f"Workflow {pipeline.workflow_id} not found") diff --git 
a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index f4e9b52778..9c12696824 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -38,7 +38,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ if not user: logger.info(click.style(f"User not found: {user_id}", fg="red")) return - tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() if not tenant: raise ValueError("Tenant not found") user.current_tenant = tenant From 6166c26ea6677aba3adb8e23490e05011fe097a5 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 17 Sep 2025 22:36:18 +0800 Subject: [PATCH 03/13] fix style check --- .../service_api/dataset/rag_pipeline/rag_pipeline_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index ad578d947e..cbc1907bf5 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -214,7 +214,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): raise UnsupportedFileTypeError() if not file.filename: - raise FilenameNotExistsError+ + raise FilenameNotExistsError if not current_user: raise ValueError("Invalid user account") From 24fc7d0d6b6a010b929523afedace27b51c04101 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 17 Sep 2025 22:40:24 +0800 Subject: [PATCH 04/13] fix(api): fix Optional not defined --- api/core/datasource/utils/message_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 996f9dc95d..d0a9eb5e74 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -100,7 +100,7 @@ class DatasourceFileMessageTransformer: ) elif message.type == DatasourceMessage.MessageType.FILE: meta = message.meta or {} - file: Optional[File] = meta.get("file") + file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: assert file.related_id is not None From 5077f8b29995693427c1e1138646aeee5a2bf1dd Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 17 Sep 2025 22:55:13 +0800 Subject: [PATCH 05/13] fix(api): fix format, replace .filter with .where --- api/controllers/console/datasets/wraps.py | 2 +- .../rag_pipeline/rag_pipeline_workflow.py | 2 +- api/core/app/apps/pipeline/pipeline_runner.py | 6 ++-- .../datasource/datasource_file_manager.py | 20 +++--------- .../datasource/entities/common_entities.py | 1 - api/core/entities/knowledge_entities.py | 1 - .../rag/extractor/entity/extract_setting.py | 1 - api/core/workflow/errors.py | 2 +- .../knowledge_index/knowledge_index_node.py | 4 +-- api/factories/file_factory.py | 2 +- api/models/dataset.py | 2 +- api/services/dataset_service.py | 2 +- api/services/datasource_provider_service.py | 4 +-- .../customized/customized_retrieval.py | 3 +- .../database/database_retrieval.py | 1 - api/services/rag_pipeline/rag_pipeline.py | 32 +++++++++---------- .../rag_pipeline/rag_pipeline_dsl_service.py | 8 ++--- api/tasks/deal_dataset_index_update_task.py | 8 ++--- 18 files changed, 
41 insertions(+), 60 deletions(-) diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 33751ab231..98abb3ef8d 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -27,7 +27,7 @@ def get_rag_pipeline( pipeline = ( db.session.query(Pipeline) - .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) + .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id) .first() ) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index cbc1907bf5..f05325d711 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -215,7 +215,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError - + if not current_user: raise ValueError("Invalid user account") diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 3b9bd224d9..ebb8b15163 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -195,9 +195,7 @@ class PipelineRunner(WorkflowBasedAppRunner): # fetch workflow by workflow_id workflow = ( db.session.query(Workflow) - .filter( - Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id - ) + .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id) .first() ) @@ -272,7 +270,7 @@ class PipelineRunner(WorkflowBasedAppRunner): if document_id and dataset_id: document = ( db.session.query(Document) - .filter(Document.id == document_id, Document.dataset_id == dataset_id) + .where(Document.id == document_id, Document.dataset_id == dataset_id) .first() ) if document: diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index f4e3c656bc..0c50c2f980 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -152,10 +152,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - upload_file: UploadFile | None = ( - db.session.query(UploadFile).where(UploadFile.id == id) - .first() - ) + upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first() if not upload_file: return None @@ -173,10 +170,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - message_file: MessageFile | None = ( - db.session.query(MessageFile).where(MessageFile.id == id) - .first() - ) + message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first() # Check if message_file is not None if message_file is not None: @@ -190,10 +184,7 @@ class DatasourceFileManager: else: tool_file_id = None - tool_file: ToolFile | None = ( - db.session.query(ToolFile).where(ToolFile.id == tool_file_id) - .first() - ) + tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first() if not tool_file: return None @@ -211,10 +202,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ - upload_file: UploadFile | None = ( - db.session.query(UploadFile).where(UploadFile.id == upload_file_id) - .first() - ) + upload_file: 
UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first() if not upload_file: return None, None diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py index 98680a5779..ac36d83ae3 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, Field diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index f6da4c7094..b9ca7414dc 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index b5eea0bf30..b9bf9d0d8c 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,4 +1,3 @@ - from pydantic import BaseModel, ConfigDict from models.dataset import Document diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 14e0315846..5bf1faee5d 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -13,4 +13,4 @@ class WorkflowNodeRunFailedError(Exception): @property def error(self) -> str: - return self._error \ No newline at end of file + return self._error diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d5ced1a246..4b6bad1aa3 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -160,7 +160,7 @@ class KnowledgeIndexNode(Node): document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.word_count = ( db.session.query(func.sum(DocumentSegment.word_count)) - .filter( + .where( DocumentSegment.document_id == document.id, DocumentSegment.dataset_id == dataset.id, ) @@ -168,7 +168,7 @@ class KnowledgeIndexNode(Node): ) db.session.add(document) # update document segment status - db.session.query(DocumentSegment).filter( + db.session.query(DocumentSegment).where( DocumentSegment.document_id == document.id, DocumentSegment.dataset_id == dataset.id, ).update( diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 41505ab025..588168bd39 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -326,7 +326,7 @@ def _build_from_datasource_file( ) -> File: datasource_file = ( db.session.query(UploadFile) - .filter( + .where( UploadFile.id == mapping.get("datasource_file_id"), UploadFile.tenant_id == tenant_id, ) diff --git a/api/models/dataset.py b/api/models/dataset.py index d620d56006..2c4059f800 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -82,7 +82,7 @@ class Dataset(Base): def total_available_documents(self): return ( db.session.query(func.count(Document.id)) - .filter( + .where( Document.dataset_id == self.id, Document.indexing_status == "completed", Document.enabled == True, diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 798233fd95..51507886ad 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -419,7 +419,7 @@ class DatasetService: def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str): dataset = ( db.session.query(Dataset) - .filter( + .where( Dataset.id != dataset_id, 
Dataset.name == name, Dataset.tenant_id == tenant_id, diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 8dceeee7ec..f05a892f93 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -690,7 +690,7 @@ class DatasourceProviderService: # Get all provider configurations of the current workspace datasource_providers: list[DatasourceProvider] = ( db.session.query(DatasourceProvider) - .filter( + .where( DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.provider == provider, DatasourceProvider.plugin_id == plugin_id, @@ -862,7 +862,7 @@ class DatasourceProviderService: # Get all provider configurations of the current workspace datasource_providers: list[DatasourceProvider] = ( db.session.query(DatasourceProvider) - .filter( + .where( DatasourceProvider.tenant_id == tenant_id, DatasourceProvider.provider == provider, DatasourceProvider.plugin_id == plugin_id, diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 82a0a08ec6..ca871bcaa1 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,3 @@ - import yaml from flask_login import current_user @@ -36,7 +35,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ pipeline_customized_templates = ( db.session.query(PipelineCustomizedTemplate) - .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) + .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language) .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc()) .all() ) diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index a544767465..ec91f79606 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,3 @@ - import yaml from extensions.ext_database import db diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 0232b9998f..cc7514aaba 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -138,7 +138,7 @@ class RagPipelineService: """ customized_template: PipelineCustomizedTemplate | None = ( db.session.query(PipelineCustomizedTemplate) - .filter( + .where( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, ) @@ -151,7 +151,7 @@ class RagPipelineService: if template_name: template = ( db.session.query(PipelineCustomizedTemplate) - .filter( + .where( PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, PipelineCustomizedTemplate.id != template_id, @@ -174,7 +174,7 @@ class RagPipelineService: """ customized_template: PipelineCustomizedTemplate | None = ( db.session.query(PipelineCustomizedTemplate) - .filter( + .where( PipelineCustomizedTemplate.id == template_id, PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id, ) @@ -192,7 +192,7 @@ class RagPipelineService: # fetch 
draft workflow by rag pipeline workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", @@ -214,7 +214,7 @@ class RagPipelineService: # fetch published workflow by workflow_id workflow = ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == pipeline.workflow_id, @@ -1015,7 +1015,7 @@ class RagPipelineService: """ limit = int(args.get("limit", 20)) - base_query = db.session.query(WorkflowRun).filter( + base_query = db.session.query(WorkflowRun).where( WorkflowRun.tenant_id == pipeline.tenant_id, WorkflowRun.app_id == pipeline.id, or_( @@ -1025,7 +1025,7 @@ class RagPipelineService: ) if args.get("last_id"): - last_workflow_run = base_query.filter( + last_workflow_run = base_query.where( WorkflowRun.id == args.get("last_id"), ).first() @@ -1033,7 +1033,7 @@ class RagPipelineService: raise ValueError("Last workflow run not exists") workflow_runs = ( - base_query.filter( + base_query.where( WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id ) .order_by(WorkflowRun.created_at.desc()) @@ -1046,7 +1046,7 @@ class RagPipelineService: has_more = False if len(workflow_runs) == limit: current_page_first_workflow_run = workflow_runs[-1] - rest_count = base_query.filter( + rest_count = base_query.where( WorkflowRun.created_at < current_page_first_workflow_run.created_at, WorkflowRun.id != current_page_first_workflow_run.id, ).count() @@ -1065,7 +1065,7 @@ class RagPipelineService: """ workflow_run = ( db.session.query(WorkflowRun) - .filter( + .where( WorkflowRun.tenant_id == pipeline.tenant_id, WorkflowRun.app_id == pipeline.id, WorkflowRun.id == run_id, @@ -1130,7 +1130,7 @@ class RagPipelineService: if template_name: template = ( db.session.query(PipelineCustomizedTemplate) - .filter( + .where( PipelineCustomizedTemplate.name == template_name, PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id, ) @@ -1168,7 +1168,7 @@ class RagPipelineService: def is_workflow_exist(self, pipeline: Pipeline) -> bool: return ( db.session.query(Workflow) - .filter( + .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == Workflow.VERSION_DRAFT, @@ -1362,10 +1362,10 @@ class RagPipelineService: """ Get datasource plugins """ - dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") - pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first() + pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found") @@ -1446,10 +1446,10 @@ class RagPipelineService: """ Get pipeline """ - dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") - pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first() + pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found") return pipeline diff --git 
a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index e21d2d56bc..88f28e03ef 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -318,7 +318,7 @@ class RagPipelineDslService: if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) - .filter( + .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, @@ -452,7 +452,7 @@ class RagPipelineDslService: if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( self._session.query(DatasetCollectionBinding) - .filter( + .where( DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, @@ -599,7 +599,7 @@ class RagPipelineDslService: ) workflow = ( self._session.query(Workflow) - .filter( + .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", @@ -673,7 +673,7 @@ class RagPipelineDslService: workflow = ( self._session.query(Workflow) - .filter( + .where( Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.version == "draft", diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index df4a76d94f..713f149c38 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -33,7 +33,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if action == "upgrade": dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -54,7 +54,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # add from vector index segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .order_by(DocumentSegment.position.asc()) .all() ) @@ -88,7 +88,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): elif action == "update": dataset_documents = ( db.session.query(DatasetDocument) - .filter( + .where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, @@ -113,7 +113,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): try: segments = ( db.session.query(DocumentSegment) - .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) + .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True) .order_by(DocumentSegment.position.asc()) .all() ) From 8cc6927fede350c959cbcdf20a14f31c2203b1aa Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 17 Sep 2025 23:04:03 +0800 Subject: [PATCH 06/13] fix mypy --- api/controllers/service_api/dataset/error.py | 6 ++++++ api/core/app/apps/base_app_generator.py | 12 ++++++------ .../app/apps/pipeline/generate_response_converter.py | 4 ++-- api/core/app/apps/pipeline/pipeline_generator.py | 2 +- api/core/datasource/datasource_manager.py | 8 
++++++-- api/core/file/models.py | 1 + api/core/workflow/nodes/loop/loop_node.py | 2 +- api/services/datasource_provider_service.py | 2 +- api/services/rag_pipeline/rag_pipeline.py | 5 +++-- 9 files changed, 27 insertions(+), 15 deletions(-) diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index e4214a16ad..ecfc37df85 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException): error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 + + +class PipelineRunError(BaseHTTPException): + error_code = "pipeline_run_error" + description = "An error occurred while running the pipeline." + code = 500 diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index b8ff40cb5f..01d025aca8 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -45,9 +45,9 @@ class BaseAppGenerator: mapping=v, tenant_id=tenant_id, config=FileUploadConfig( - allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType] - allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType] - allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType] + allowed_file_types=entity_dictionary[k].allowed_file_types or [], + allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], + allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), strict_type_validation=strict_type_validation, ) @@ -60,9 +60,9 @@ class BaseAppGenerator: mappings=v, tenant_id=tenant_id, config=FileUploadConfig( - allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType] - allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType] - allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType] + allowed_file_types=entity_dictionary[k].allowed_file_types or [], + allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], + allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), ) for k, v in user_inputs.items() diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py index f47db16c18..cfacd8640d 100644 --- a/api/core/app/apps/pipeline/generate_response_converter.py +++ b/api/core/app/apps/pipeline/generate_response_converter.py @@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(cast(dict, data)) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump()) yield response_chunk @classmethod @@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter): elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict())) else: - response_chunk.update(sub_stream_response.to_dict()) + response_chunk.update(sub_stream_response.model_dump()) yield response_chunk diff --git 
a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index b1e98ed3ea..780e3a53db 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -744,7 +744,7 @@ class PipelineGenerator(BaseAppGenerator): Format datasource info list. """ if datasource_type == "online_drive": - all_files = [] + all_files: list[Mapping[str, Any]] = [] datasource_node_data = None datasource_nodes = workflow.graph_dict.get("nodes", []) for datasource_node in datasource_nodes: diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 3144712fe9..31f982e960 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -46,7 +46,7 @@ class DatasourceManager: provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id) if not provider_entity: raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found") - controller = None + controller: DatasourcePluginProviderController | None = None match datasource_type: case DatasourceProviderType.ONLINE_DOCUMENT: controller = OnlineDocumentDatasourcePluginProviderController( @@ -79,8 +79,12 @@ class DatasourceManager: case _: raise ValueError(f"Unsupported datasource type: {datasource_type}") - datasource_plugin_providers[provider_id] = controller + if controller: + datasource_plugin_providers[provider_id] = controller + if controller is None: + raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.") + return controller @classmethod diff --git a/api/core/file/models.py b/api/core/file/models.py index a242851a63..7089b7ce7a 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -119,6 +119,7 @@ class File(BaseModel): assert self.related_id is not None assert self.extension is not None return sign_tool_file(tool_file_id=self.related_id, extension=self.extension) + return None def to_plugin_parameter(self) -> dict[str, Any]: return { diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 25b0c4f4fe..d783290e51 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -92,7 +92,7 @@ class LoopNode(Node): if self._node_data.loop_variables: value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value), + "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None, } for loop_variable in self._node_data.loop_variables: if loop_variable.value_type not in value_processor: diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 8dceeee7ec..ae4ab5947b 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -77,7 +77,7 @@ class DatasourceProviderService: provider_id=f"{plugin_id}/{provider}", credential_type=CredentialType.of(datasource_provider.auth_type), ) - encrypted_credentials = raw_credentials.copy() + encrypted_credentials = dict(raw_credentials) for key, value in encrypted_credentials.items(): if key in provider_credential_secret_variables: encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value) diff --git 
a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index e7c255a86a..07b0e53cbe 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -378,12 +378,12 @@ class RagPipelineService: Get default block configs """ # return default block config - default_block_configs = [] + default_block_configs: list[dict[str, Any]] = [] for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): node_class = node_class_mapping[LATEST_VERSION] default_config = node_class.get_default_config() if default_config: - default_block_configs.append(default_config) + default_block_configs.append(dict(default_config)) return default_block_configs @@ -631,6 +631,7 @@ class RagPipelineService: try: for website_crawl_message in website_crawl_result: end_time = time.time() + crawl_event: DatasourceCompletedEvent | DatasourceProcessingEvent if website_crawl_message.result.status == "completed": crawl_event = DatasourceCompletedEvent( data=website_crawl_message.result.web_info_list or [], From ea38b4bcbeeddcb12e6ffe61184635173af311ab Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 17 Sep 2025 23:15:03 +0800 Subject: [PATCH 07/13] fix mypy --- .../service_api/dataset/rag_pipeline/rag_pipeline_workflow.py | 2 +- api/core/datasource/datasource_manager.py | 2 +- api/core/workflow/nodes/loop/loop_node.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index cbc1907bf5..f05325d711 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -215,7 +215,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError - + if not current_user: raise ValueError("Invalid user account") diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 31f982e960..47d297e194 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -84,7 +84,7 @@ class DatasourceManager: if controller is None: raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.") - + return controller @classmethod diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index d783290e51..2b988ad944 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -92,7 +92,9 @@ class LoopNode(Node): if self._node_data.loop_variables: value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None, + "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value) + if isinstance(var.value, list) + else None, } for loop_variable in self._node_data.loop_variables: if loop_variable.value_type not in value_processor: From 42d76dd12688a07351d1e43310d2f26b8fa45f76 Mon Sep 17 00:00:00 2001 From: jyong <718720800@qq.com> Date: Wed, 17 Sep 2025 23:19:57 +0800 Subject: [PATCH 08/13] fix mypy --- api/configs/__init__.py | 2 +- api/configs/feature/__init__.py | 4 ++-- api/core/workflow/errors.py | 
2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/configs/__init__.py b/api/configs/__init__.py index 04642b5e9a..1932046322 100644 --- a/api/configs/__init__.py +++ b/api/configs/__init__.py @@ -1,3 +1,3 @@ from .app_config import DifyConfig -dify_config = DifyConfig() # pyright: ignore[reportCallIssue] +dify_config = DifyConfig() # type: ignore diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 6d3934a557..db6f1e592c 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -512,11 +512,11 @@ class WorkflowVariableTruncationConfig(BaseSettings): description="Maximum size for variable to trigger final truncation.", ) WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field( - 50000, + 100000, description="maximum length for string to trigger tuncation, measure in number of characters", ) WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field( - 100, + 1000, description="maximum length for array to trigger truncation.", ) diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index 14e0315846..5bf1faee5d 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -13,4 +13,4 @@ class WorkflowNodeRunFailedError(Exception): @property def error(self) -> str: - return self._error \ No newline at end of file + return self._error From eed82f7ca714961c3aa8377901e407be3b1f3c73 Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 17 Sep 2025 23:23:58 +0800 Subject: [PATCH 09/13] fix(api): update user retrieval logic in get_user function --- api/controllers/inner_api/plugin/wraps.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 3776d0be0e..04102c49f3 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -32,11 +32,20 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = ( session.query(EndUser) .where( - EndUser.session_id == user_id, + EndUser.id == user_id, EndUser.tenant_id == tenant_id, ) .first() ) + if not user_model: + user_model = ( + session.query(EndUser) + .where( + EndUser.session_id == user_id, + EndUser.tenant_id == tenant_id, + ) + .first() + ) if not user_model: user_model = EndUser( tenant_id=tenant_id, From 73e8623f07baacf6a4e2189e82e70aa058d8366f Mon Sep 17 00:00:00 2001 From: Harry Date: Wed, 17 Sep 2025 23:42:32 +0800 Subject: [PATCH 10/13] fix(api): simplify parameters in get_signed_file_url_for_plugin function --- api/controllers/inner_api/plugin/plugin.py | 8 +------- api/core/file/helpers.py | 8 ++------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9c26d64510..c5bb2f2545 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -420,13 +420,7 @@ class PluginUploadFileRequestApi(Resource): ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): # generate signed url - url = get_signed_file_url_for_plugin( - payload.filename, - payload.mimetype, - tenant_model.id, - user_model.id, - user_model.session_id if isinstance(user_model, EndUser) else None, - ) + url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id) return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() diff --git 
a/api/core/file/helpers.py b/api/core/file/helpers.py index 37ed8275c2..6d553d7dc6 100644 --- a/api/core/file/helpers.py +++ b/api/core/file/helpers.py @@ -25,9 +25,7 @@ def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str: return f"{url}?{query_string}" -def get_signed_file_url_for_plugin( - filename: str, mimetype: str, tenant_id: str, user_id: str, session_id: str | None -) -> str: +def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: # Plugin access should use internal URL for Docker network communication base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL url = f"{base_url}/files/upload/for-plugin" @@ -37,9 +35,7 @@ def get_signed_file_url_for_plugin( msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() encoded_sign = base64.urlsafe_b64encode(sign).decode() - - url_user_id = session_id or user_id - return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={url_user_id}&tenant_id={tenant_id}" + return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" def verify_plugin_file_signature( From 495562e3138bad56af5123a3dc4ed6b00c2ad1be Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 17 Sep 2025 23:48:45 +0800 Subject: [PATCH 11/13] chore(api): fix incorrect assertion message --- .../sqlalchemy_workflow_node_execution_repository.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 5226a1071f..fc160cbbe4 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -291,7 +291,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return None value_json = _deterministic_json_dump(json_encodable_value) - assert value_json is not None, "value_json should be None here." + assert value_json is not None, "value_json should be not None here." 
suffix = type_.value upload_file = self._file_service.upload_file( From 6371cc502843482f60fa27150469ce9a00c26c01 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 17 Sep 2025 23:49:27 +0800 Subject: [PATCH 12/13] Apply suggestion from @gemini-code-assist[bot] Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- api/controllers/inner_api/plugin/plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index c5bb2f2545..2ee1f6c988 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -420,7 +420,7 @@ class PluginUploadFileRequestApi(Resource): ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): # generate signed url - url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id) + url = get_signed_file_url_for_plugin(filename=payload.filename, mimetype=payload.mimetype, tenant_id=tenant_model.id, user_id=user_model.id) return BaseBackwardsInvocationResponse(data={"url": url}).model_dump() From 370127b87a397964f8a38fc72cd831fc2a66a7a7 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Wed, 17 Sep 2025 23:58:30 +0800 Subject: [PATCH 13/13] fix(api): fix line too long --- api/controllers/inner_api/plugin/plugin.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 2ee1f6c988..deab50076d 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -420,7 +420,12 @@ class PluginUploadFileRequestApi(Resource): ) def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile): # generate signed url - url = get_signed_file_url_for_plugin(filename=payload.filename, mimetype=payload.mimetype, tenant_id=tenant_model.id, user_id=user_model.id) + url = get_signed_file_url_for_plugin( + filename=payload.filename, + mimetype=payload.mimetype, + tenant_id=tenant_model.id, + user_id=user_model.id, + ) return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
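Reviewer note on patches 10-13: they all touch the signed plugin file-upload URL. Below is a minimal sketch of how the signing in get_signed_file_url_for_plugin pairs with verification, assuming a shared secret key (the real key source is not shown in these hunks) and using hypothetical helper names build_plugin_upload_query / verify_plugin_upload_query for illustration only — this is not the code being merged.

import base64
import hashlib
import hmac
import secrets
import time

# Hypothetical shared secret; the actual key comes from application config,
# which these hunks do not show.
SECRET_KEY = b"example-secret"


def build_plugin_upload_query(filename: str, mimetype: str, tenant_id: str, user_id: str) -> dict[str, str]:
    # Sign the upload parameters the way the patched helper does:
    # HMAC-SHA256 over "upload|filename|mimetype|tenant_id|user_id|timestamp|nonce".
    timestamp = str(int(time.time()))
    nonce = secrets.token_hex(16)
    msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
    sign = hmac.new(SECRET_KEY, msg.encode(), hashlib.sha256).digest()
    return {
        "timestamp": timestamp,
        "nonce": nonce,
        "sign": base64.urlsafe_b64encode(sign).decode(),
        "user_id": user_id,
        "tenant_id": tenant_id,
    }


def verify_plugin_upload_query(filename: str, mimetype: str, query: dict[str, str], max_age: int = 300) -> bool:
    # Recompute the signature from the query parameters and compare in constant time,
    # then reject stale timestamps so a leaked URL cannot be replayed indefinitely.
    msg = (
        f"upload|{filename}|{mimetype}|{query['tenant_id']}|{query['user_id']}"
        f"|{query['timestamp']}|{query['nonce']}"
    )
    expected = hmac.new(SECRET_KEY, msg.encode(), hashlib.sha256).digest()
    given = base64.urlsafe_b64decode(query["sign"])
    if not hmac.compare_digest(expected, given):
        return False
    return int(time.time()) - int(query["timestamp"]) <= max_age

After patch 10 the user_id baked into the signature is always the Account/EndUser primary id rather than the session id, which lines up with patch 09's get_user change: resolve by EndUser.id first and only fall back to session_id for existing records.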