From 8cc6927fede350c959cbcdf20a14f31c2203b1aa Mon Sep 17 00:00:00 2001
From: jyong <718720800@qq.com>
Date: Wed, 17 Sep 2025 23:04:03 +0800
Subject: [PATCH] fix mypy

---
 api/controllers/service_api/dataset/error.py          |  6 ++++++
 api/core/app/apps/base_app_generator.py               | 12 ++++++------
 .../app/apps/pipeline/generate_response_converter.py  |  4 ++--
 api/core/app/apps/pipeline/pipeline_generator.py      |  2 +-
 api/core/datasource/datasource_manager.py             |  8 ++++++--
 api/core/file/models.py                               |  1 +
 api/core/workflow/nodes/loop/loop_node.py             |  2 +-
 api/services/datasource_provider_service.py           |  2 +-
 api/services/rag_pipeline/rag_pipeline.py             |  5 +++--
 9 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py
index e4214a16ad..ecfc37df85 100644
--- a/api/controllers/service_api/dataset/error.py
+++ b/api/controllers/service_api/dataset/error.py
@@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException):
     error_code = "dataset_in_use"
     description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
     code = 409
+
+
+class PipelineRunError(BaseHTTPException):
+    error_code = "pipeline_run_error"
+    description = "An error occurred while running the pipeline."
+    code = 500
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index b8ff40cb5f..01d025aca8 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -45,9 +45,9 @@ class BaseAppGenerator:
                 mapping=v,
                 tenant_id=tenant_id,
                 config=FileUploadConfig(
-                    allowed_file_types=entity_dictionary[k].allowed_file_types,  # pyright: ignore[reportArgumentType]
-                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,  # pyright: ignore[reportArgumentType]
-                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,  # pyright: ignore[reportArgumentType]
+                    allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                 ),
                 strict_type_validation=strict_type_validation,
             )
@@ -60,9 +60,9 @@
                 mappings=v,
                 tenant_id=tenant_id,
                 config=FileUploadConfig(
-                    allowed_file_types=entity_dictionary[k].allowed_file_types,  # pyright: ignore[reportArgumentType]
-                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,  # pyright: ignore[reportArgumentType]
-                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,  # pyright: ignore[reportArgumentType]
+                    allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+                    allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+                    allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
                 ),
             )
             for k, v in user_inputs.items()
diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py
index f47db16c18..cfacd8640d 100644
--- a/api/core/app/apps/pipeline/generate_response_converter.py
+++ b/api/core/app/apps/pipeline/generate_response_converter.py
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(cast(dict, data))
             else:
-                response_chunk.update(sub_stream_response.to_dict())
+                response_chunk.update(sub_stream_response.model_dump())
             yield response_chunk

     @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
         elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
             response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
         else:
-            response_chunk.update(sub_stream_response.to_dict())
+            response_chunk.update(sub_stream_response.model_dump())
         yield response_chunk
diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py
index b1e98ed3ea..780e3a53db 100644
--- a/api/core/app/apps/pipeline/pipeline_generator.py
+++ b/api/core/app/apps/pipeline/pipeline_generator.py
@@ -744,7 +744,7 @@ class PipelineGenerator(BaseAppGenerator):
         Format datasource info list.
         """
         if datasource_type == "online_drive":
-            all_files = []
+            all_files: list[Mapping[str, Any]] = []
             datasource_node_data = None
             datasource_nodes = workflow.graph_dict.get("nodes", [])
             for datasource_node in datasource_nodes:
diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py
index 3144712fe9..31f982e960 100644
--- a/api/core/datasource/datasource_manager.py
+++ b/api/core/datasource/datasource_manager.py
@@ -46,7 +46,7 @@ class DatasourceManager:
         provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
         if not provider_entity:
             raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
-        controller = None
+        controller: DatasourcePluginProviderController | None = None
         match datasource_type:
             case DatasourceProviderType.ONLINE_DOCUMENT:
                 controller = OnlineDocumentDatasourcePluginProviderController(
@@ -79,8 +79,12 @@
             case _:
                 raise ValueError(f"Unsupported datasource type: {datasource_type}")

-        datasource_plugin_providers[provider_id] = controller
+        if controller:
+            datasource_plugin_providers[provider_id] = controller

+        if controller is None:
+            raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")
+
         return controller

     @classmethod
diff --git a/api/core/file/models.py b/api/core/file/models.py
index a242851a63..7089b7ce7a 100644
--- a/api/core/file/models.py
+++ b/api/core/file/models.py
@@ -119,6 +119,7 @@ class File(BaseModel):
             assert self.related_id is not None
             assert self.extension is not None
             return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
+        return None

     def to_plugin_parameter(self) -> dict[str, Any]:
         return {
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 25b0c4f4fe..d783290e51 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -92,7 +92,7 @@ class LoopNode(Node):
         if self._node_data.loop_variables:
             value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
                 "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
-                "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value),
+                "variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None,
             }
             for loop_variable in self._node_data.loop_variables:
                 if loop_variable.value_type not in value_processor:
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 8dceeee7ec..ae4ab5947b 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -77,7 +77,7 @@ class DatasourceProviderService:
             provider_id=f"{plugin_id}/{provider}",
             credential_type=CredentialType.of(datasource_provider.auth_type),
         )
-        encrypted_credentials = raw_credentials.copy()
+        encrypted_credentials = dict(raw_credentials)
         for key, value in encrypted_credentials.items():
             if key in provider_credential_secret_variables:
                 encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index e7c255a86a..07b0e53cbe 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -378,12 +378,12 @@ class RagPipelineService:
         Get default block configs
         """
         # return default block config
-        default_block_configs = []
+        default_block_configs: list[dict[str, Any]] = []
         for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
             node_class = node_class_mapping[LATEST_VERSION]
             default_config = node_class.get_default_config()
             if default_config:
-                default_block_configs.append(default_config)
+                default_block_configs.append(dict(default_config))

         return default_block_configs

@@ -631,6 +631,7 @@ class RagPipelineService:
         try:
             for website_crawl_message in website_crawl_result:
                 end_time = time.time()
+                crawl_event: DatasourceCompletedEvent | DatasourceProcessingEvent
                 if website_crawl_message.result.status == "completed":
                     crawl_event = DatasourceCompletedEvent(
                         data=website_crawl_message.result.web_info_list or [],
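
Note: the recurring fix in this patch is one mypy-friendly pattern: declare a variable with an explicit union annotation before it is assigned in separate branches, then narrow the union before use (the datasource_manager.py and rag_pipeline.py hunks both follow it); the to_dict() -> model_dump() change simply switches to Pydantic v2's built-in serializer. Below is a minimal standalone sketch of the narrowing pattern; the Controller and build_controller names are hypothetical, not from the Dify codebase.

    from dataclasses import dataclass

    @dataclass
    class Controller:
        provider_id: str

    def build_controller(provider_id: str, kind: str) -> Controller:
        # Annotate the union up front so mypy tracks the variable across all
        # match arms, mirroring the
        # `controller: DatasourcePluginProviderController | None = None` change.
        controller: Controller | None = None
        match kind:
            case "online_document" | "online_drive":
                controller = Controller(provider_id)
            case _:
                raise ValueError(f"Unsupported datasource type: {kind}")

        # Narrow `Controller | None` to `Controller` before returning, as the
        # second datasource_manager.py hunk does before `return controller`.
        if controller is None:
            raise LookupError(f"Datasource provider {provider_id} not found.")
        return controller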