From 491fa9923b8d1fd3820a5df40eda4ebf22affdf5 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Mon, 2 Feb 2026 21:03:16 +0900 Subject: [PATCH] refactor: port api/controllers/console/datasets/data_source.py /datasets/metadata.py /service_api/dataset/metadata.py /nodes/agent/agent_node.py api/core/workflow/nodes/datasource/datasource_node.py api/services/dataset_service.py to match case (#31836) --- .../console/datasets/data_source.py | 40 ++-- api/controllers/console/datasets/metadata.py | 9 +- .../service_api/dataset/metadata.py | 9 +- api/core/workflow/nodes/agent/agent_node.py | 64 +++--- .../nodes/datasource/datasource_node.py | 193 +++++++++--------- api/services/dataset_service.py | 114 ++++++----- 6 files changed, 223 insertions(+), 206 deletions(-) diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 01e9bf77c0..daef4e005a 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,6 @@ import json from collections.abc import Generator -from typing import Any, cast +from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal_with @@ -157,9 +157,8 @@ class DataSourceApi(Resource): @setup_required @login_required @account_initialization_required - def patch(self, binding_id, action): + def patch(self, binding_id, action: Literal["enable", "disable"]): binding_id = str(binding_id) - action = str(action) with Session(db.engine) as session: data_source_binding = session.execute( select(DataSourceOauthBinding).filter_by(id=binding_id) @@ -167,23 +166,24 @@ class DataSourceApi(Resource): if data_source_binding is None: raise NotFound("Data source binding not found.") # enable binding - if action == "enable": - if data_source_binding.disabled: - data_source_binding.disabled = False - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is not disabled.") - # disable binding - if action == "disable": - if not data_source_binding.disabled: - data_source_binding.disabled = True - data_source_binding.updated_at = naive_utc_now() - db.session.add(data_source_binding) - db.session.commit() - else: - raise ValueError("Data source is disabled.") + match action: + case "enable": + if data_source_binding.disabled: + data_source_binding.disabled = False + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is not disabled.") + # disable binding + case "disable": + if not data_source_binding.disabled: + data_source_binding.disabled = True + data_source_binding.updated_at = naive_utc_now() + db.session.add(data_source_binding) + db.session.commit() + else: + raise ValueError("Data source is disabled.") return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index 05fc4cd714..2e69ddc5ab 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -126,10 +126,11 @@ class DatasetMetadataBuiltInFieldActionApi(Resource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index b8d9508004..692342a38a 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -168,10 +168,11 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): raise NotFound("Dataset not found.") DatasetService.check_dataset_permission(dataset, current_user) - if action == "enable": - MetadataService.enable_built_in_field(dataset) - elif action == "disable": - MetadataService.disable_built_in_field(dataset) + match action: + case "enable": + MetadataService.enable_built_in_field(dataset) + case "disable": + MetadataService.disable_built_in_field(dataset) return {"result": "success"}, 200 diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5a365f769d..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -192,32 +192,33 @@ class AgentNode(Node[AgentNodeData]): result[parameter_name] = None continue agent_input = node_data.agent_parameters[parameter_name] - if agent_input.type == "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - elif agent_input.type in {"mixed", "constant"}: - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + # variable_pool.convert_template expects a string template, + # but if passing a dict, convert to JSON string first before rendering + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - else: - raise AgentInputTypeError(agent_input.type) + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + # variable_pool.convert_template returns a string, + # so we need to convert it back to a dictionary + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -374,12 +375,13 @@ class AgentNode(Node[AgentNodeData]): result: dict[str, Any] = {} for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] - if input.type in ["mixed", "constant"]: - selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value + match input.type: + case "mixed" | "constant": + selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value result = {node_id + "." + key: value for key, value in result.items()} diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index fd71d610b4..a732a70417 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -270,15 +270,18 @@ class DatasourceNode(Node[DatasourceNodeData]): if typed_node_data.datasource_parameters: for parameter_name in typed_node_data.datasource_parameters: input = typed_node_data.datasource_parameters[parameter_name] - if input.type == "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - elif input.type == "variable": - result[parameter_name] = input.value - elif input.type == "constant": - pass + match input.type: + case "mixed": + assert isinstance(input.value, str) + selectors = VariableTemplateParser(input.value).extract_variable_selectors() + for selector in selectors: + result[selector.variable] = selector.value_selector + case "variable": + result[parameter_name] = input.value + case "constant": + pass + case None: + pass result = {node_id + "." + key: value for key, value in result.items()} @@ -308,99 +311,107 @@ class DatasourceNode(Node[DatasourceNodeData]): variables: dict[str, Any] = {} for message in message_stream: - if message.type in { - DatasourceMessage.MessageType.IMAGE_LINK, - DatasourceMessage.MessageType.BINARY_LINK, - DatasourceMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, DatasourceMessage.TextMessage) + match message.type: + case ( + DatasourceMessage.MessageType.IMAGE_LINK + | DatasourceMessage.MessageType.BINARY_LINK + | DatasourceMessage.MessageType.IMAGE + ): + assert isinstance(message.message, DatasourceMessage.TextMessage) - url = message.message.text - transfer_method = FileTransferMethod.TOOL_FILE + url = message.message.text + transfer_method = FileTransferMethod.TOOL_FILE - datasource_file_id = str(url).split("/")[-1].split(".")[0] + datasource_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"Tool file {datasource_file_id} does not exist") + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"Tool file {datasource_file_id} does not exist") - mapping = { - "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=self.tenant_id, - ) - files.append(file) - elif message.type == DatasourceMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, DatasourceMessage.TextMessage) - assert message.meta - - datasource_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) - datasource_file = session.scalar(stmt) - if datasource_file is None: - raise ToolFileError(f"datasource file {datasource_file_id} not exists") - - mapping = { - "tool_file_id": datasource_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( + mapping = { + "tool_file_id": datasource_file_id, + "type": file_factory.get_file_type_by_mime_type(datasource_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( mapping=mapping, tenant_id=self.tenant_id, ) - ) - elif message.type == DatasourceMessage.MessageType.TEXT: - assert isinstance(message.message, DatasourceMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.JSON: - assert isinstance(message.message, DatasourceMessage.JsonMessage) - json.append(message.message.json_object) - elif message.type == DatasourceMessage.MessageType.LINK: - assert isinstance(message.message, DatasourceMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == DatasourceMessage.MessageType.VARIABLE: - assert isinstance(message.message, DatasourceMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value + files.append(file) + case DatasourceMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, DatasourceMessage.TextMessage) + assert message.meta + datasource_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == datasource_file_id) + datasource_file = session.scalar(stmt) + if datasource_file is None: + raise ToolFileError(f"datasource file {datasource_file_id} not exists") + + mapping = { + "tool_file_id": datasource_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self.tenant_id, + ) + ) + case DatasourceMessage.MessageType.TEXT: + assert isinstance(message.message, DatasourceMessage.TextMessage) + text += message.message.text yield StreamChunkEvent( - selector=[self._node_id, variable_name], - chunk=variable_value, + selector=[self._node_id, "text"], + chunk=message.message.text, is_final=False, ) - else: - variables[variable_name] = variable_value - elif message.type == DatasourceMessage.MessageType.FILE: - assert message.meta is not None - files.append(message.meta["file"]) + case DatasourceMessage.MessageType.JSON: + assert isinstance(message.message, DatasourceMessage.JsonMessage) + json.append(message.message.json_object) + case DatasourceMessage.MessageType.LINK: + assert isinstance(message.message, DatasourceMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[self._node_id, "text"], + chunk=stream_text, + is_final=False, + ) + case DatasourceMessage.MessageType.VARIABLE: + assert isinstance(message.message, DatasourceMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[self._node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + case DatasourceMessage.MessageType.FILE: + assert message.meta is not None + files.append(message.meta["file"]) + case ( + DatasourceMessage.MessageType.BLOB_CHUNK + | DatasourceMessage.MessageType.LOG + | DatasourceMessage.MessageType.RETRIEVER_RESOURCES + ): + pass + # mark the end of the stream yield StreamChunkEvent( selector=[self._node_id, "text"], diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 16945fca6a..1ea6c4e1c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2978,14 +2978,15 @@ class DocumentService: """ now = naive_utc_now() - if action == "enable": - return DocumentService._prepare_enable_update(document, now) - elif action == "disable": - return DocumentService._prepare_disable_update(document, user, now) - elif action == "archive": - return DocumentService._prepare_archive_update(document, user, now) - elif action == "un_archive": - return DocumentService._prepare_unarchive_update(document, now) + match action: + case "enable": + return DocumentService._prepare_enable_update(document, now) + case "disable": + return DocumentService._prepare_disable_update(document, user, now) + case "archive": + return DocumentService._prepare_archive_update(document, user, now) + case "un_archive": + return DocumentService._prepare_unarchive_update(document, now) return None @@ -3622,56 +3623,57 @@ class SegmentService: # Check if segment_ids is not empty to avoid WHERE false condition if not segment_ids or len(segment_ids) == 0: return - if action == "enable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == False, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = True - segment.disabled_at = None - segment.disabled_by = None - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + match action: + case "enable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == False, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = True + segment.disabled_at = None + segment.disabled_by = None + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) - elif action == "disable": - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.id.in_(segment_ids), - DocumentSegment.dataset_id == dataset.id, - DocumentSegment.document_id == document.id, - DocumentSegment.enabled == True, - ) - ).all() - if not segments: - return - real_deal_segment_ids = [] - for segment in segments: - indexing_cache_key = f"segment_{segment.id}_indexing" - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - continue - segment.enabled = False - segment.disabled_at = naive_utc_now() - segment.disabled_by = current_user.id - db.session.add(segment) - real_deal_segment_ids.append(segment.id) - db.session.commit() + enable_segments_to_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + case "disable": + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.id.in_(segment_ids), + DocumentSegment.dataset_id == dataset.id, + DocumentSegment.document_id == document.id, + DocumentSegment.enabled == True, + ) + ).all() + if not segments: + return + real_deal_segment_ids = [] + for segment in segments: + indexing_cache_key = f"segment_{segment.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + continue + segment.enabled = False + segment.disabled_at = naive_utc_now() + segment.disabled_by = current_user.id + db.session.add(segment) + real_deal_segment_ids.append(segment.id) + db.session.commit() - disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) + disable_segments_from_index_task.delay(real_deal_segment_ids, dataset.id, document.id) @classmethod def create_child_chunk(