diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 4be006de11..5090cccb5d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -303,9 +303,10 @@ class AgentNode(Node[AgentNodeData]): if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) model_instance, model_schema = self._fetch_model(value) - # memory config + # memory config - only for LLM models history_prompt_messages = [] - if node_data.memory: + model_type_str = value.get("model_type", ModelType.LLM.value) + if node_data.memory and model_type_str == ModelType.LLM.value: memory = self._fetch_memory(model_instance) if memory: prompt_messages = memory.get_history_prompt_messages( @@ -415,12 +416,13 @@ class AgentNode(Node[AgentNodeData]): def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: provider_manager = ProviderManager() + model_type = ModelType(value.get("model_type", ModelType.LLM.value)) provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM + tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=model_type ) model_name = value.get("model", "") model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, model=model_name + model_type=model_type, model=model_name ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance