diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0bc93ad34d..9b8bf566c1 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -397,39 +397,40 @@ class CotAgentRunner(BaseAgentRunner, ABC): current_scratchpad: AgentScratchpadUnit | None = None for message in self.history_prompt_messages: - if isinstance(message, AssistantPromptMessage): - if not current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad = AgentScratchpadUnit( - agent_response=message.content, - thought=message.content or "I am thinking about how to help you", - action_str="", - action=None, - observation=None, - ) - scratchpads.append(current_scratchpad) - if message.tool_calls: - try: - current_scratchpad.action = AgentScratchpadUnit.Action( - action_name=message.tool_calls[0].function.name, - action_input=json.loads(message.tool_calls[0].function.arguments), + match message: + case AssistantPromptMessage(): + if not current_scratchpad: + assert isinstance(message.content, str) + current_scratchpad = AgentScratchpadUnit( + agent_response=message.content, + thought=message.content or "I am thinking about how to help you", + action_str="", + action=None, + observation=None, ) - current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) - except Exception: - logger.exception("Failed to parse tool call from assistant message") - elif isinstance(message, ToolPromptMessage): - if current_scratchpad: - assert isinstance(message.content, str) - current_scratchpad.observation = message.content - else: - raise NotImplementedError("expected str type") - elif isinstance(message, UserPromptMessage): - if scratchpads: - result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) - scratchpads = [] - current_scratchpad = None + scratchpads.append(current_scratchpad) + if message.tool_calls: + try: + current_scratchpad.action = AgentScratchpadUnit.Action( + action_name=message.tool_calls[0].function.name, + action_input=json.loads(message.tool_calls[0].function.arguments), + ) + current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict()) + except Exception: + logger.exception("Failed to parse tool call from assistant message") + case ToolPromptMessage(): + if current_scratchpad: + assert isinstance(message.content, str) + current_scratchpad.observation = message.content + else: + raise NotImplementedError("expected str type") + case UserPromptMessage(): + if scratchpads: + result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) + scratchpads = [] + current_scratchpad = None - result.append(message) + result.append(message) if scratchpads: result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads))) diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 4f3c74deea..6aa1b85028 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -134,17 +134,18 @@ class AdvancedChatAppGenerateResponseConverter( "created_at": chunk.created_at, } - if isinstance(sub_stream_response, MessageEndStreamResponse): - sub_stream_response_dict = sub_stream_response.model_dump(mode="json") - metadata = sub_stream_response_dict.get("metadata", {}) - sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) - response_chunk.update(sub_stream_response_dict) - elif isinstance(sub_stream_response, ErrorStreamResponse): - data = cls._error_to_stream_response(sub_stream_response.err) - response_chunk.update(data) - elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): - response_chunk.update(sub_stream_response.to_ignore_detail_dict()) - else: - response_chunk.update(sub_stream_response.model_dump(mode="json")) + match sub_stream_response: + case MessageEndStreamResponse(): + sub_stream_response_dict = sub_stream_response.model_dump(mode="json") + metadata = sub_stream_response_dict.get("metadata", {}) + sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) + response_chunk.update(sub_stream_response_dict) + case ErrorStreamResponse(): + data = cls._error_to_stream_response(sub_stream_response.err) + response_chunk.update(data) + case NodeStartStreamResponse() | NodeFinishStreamResponse(): + response_chunk.update(sub_stream_response.to_ignore_detail_dict()) + case _: + response_chunk.update(sub_stream_response.model_dump(mode="json")) yield response_chunk diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 0ca682e87a..ea4a187a9c 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -305,26 +305,27 @@ class AppRunner: text += message.content elif isinstance(message.content, list): for content in message.content: - if isinstance(content, str): - text += content - elif isinstance(content, TextPromptMessageContent): - text += content.data - elif isinstance(content, ImagePromptMessageContent): - if message_id and user_id and tenant_id: - try: - self._handle_multimodal_image_content( - content=content, - message_id=message_id, - user_id=user_id, - tenant_id=tenant_id, - queue_manager=queue_manager, - ) - except Exception: - _logger.exception("Failed to handle multimodal image output") - else: - _logger.warning("Received multimodal output but missing required parameters") - else: - text += content.data if hasattr(content, "data") else str(content) + match content: + case str(): + text += content + case TextPromptMessageContent(): + text += content.data + case ImagePromptMessageContent(): + if message_id and user_id and tenant_id: + try: + self._handle_multimodal_image_content( + content=content, + message_id=message_id, + user_id=user_id, + tenant_id=tenant_id, + queue_manager=queue_manager, + ) + except Exception: + _logger.exception("Failed to handle multimodal image output") + else: + _logger.warning("Received multimodal output but missing required parameters") + case _: + text += content.data if hasattr(content, "data") else str(content) if not model: model = result.model diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 38864a1830..3a6c314159 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -77,17 +77,16 @@ class TemplateTransformer(ABC): """ def convert_scientific_notation(value: Any) -> Any: - if isinstance(value, str): - # Check if the string looks like scientific notation - if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): + match value: + case str() if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE): try: return float(value) except ValueError: pass - elif isinstance(value, dict): - return {k: convert_scientific_notation(v) for k, v in value.items()} - elif isinstance(value, list): - return [convert_scientific_notation(v) for v in value] + case dict(): + return {k: convert_scientific_notation(v) for k, v in value.items()} + case list(): + return [convert_scientific_notation(v) for v in value] return value return convert_scientific_notation(result) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 884610ca82..fadb6fa2d6 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -72,20 +72,21 @@ def handle_mcp_request( try: # Dispatch request to appropriate handler based on instance type - if isinstance(request_root, mcp_types.InitializeRequest): - return create_success_response(handle_initialize(mcp_server.description)) - elif isinstance(request_root, mcp_types.ListToolsRequest): - return create_success_response( - handle_list_tools( - app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + match request_root: + case mcp_types.InitializeRequest(): + return create_success_response(handle_initialize(mcp_server.description)) + case mcp_types.ListToolsRequest(): + return create_success_response( + handle_list_tools( + app.name, app.mode, user_input_form, mcp_server.description, mcp_server.parameters_dict + ) ) - ) - elif isinstance(request_root, mcp_types.CallToolRequest): - return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) - elif isinstance(request_root, mcp_types.PingRequest): - return create_success_response(handle_ping()) - else: - return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") + case mcp_types.CallToolRequest(): + return create_success_response(handle_call_tool(app, request, user_input_form, end_user)) + case mcp_types.PingRequest(): + return create_success_response(handle_ping()) + case _: + return create_error_response(mcp_types.METHOD_NOT_FOUND, f"Method not found: {request_type.__name__}") except ValueError as e: logger.exception("Invalid params")