From f0ff2e1f2ce7bff8b23671b04397f234502c5cbc Mon Sep 17 00:00:00 2001 From: Stream Date: Tue, 28 Oct 2025 13:04:01 +0800 Subject: [PATCH] refactor: add node_id to MemoryBlockSpec --- api/core/memory/entities.py | 17 ++++++- .../entities/advanced_prompt_entities.py | 6 --- api/core/workflow/nodes/llm/node.py | 46 +++++++++---------- api/fields/workflow_fields.py | 1 + 4 files changed, 40 insertions(+), 30 deletions(-) diff --git a/api/core/memory/entities.py b/api/core/memory/entities.py index e0f6ed840f..4d15a8c9d6 100644 --- a/api/core/memory/entities.py +++ b/api/core/memory/entities.py @@ -4,7 +4,7 @@ from enum import StrEnum from typing import Optional from uuid import uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from core.app.app_config.entities import ModelConfig @@ -49,6 +49,21 @@ class MemoryBlockSpec(BaseModel): model: ModelConfig = Field(description="Model configuration for memory updates") end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users") end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users") + node_id: str | None = Field( + default=None, + description="Node ID when scope is NODE. Must be None when scope is APP." + ) + + @field_validator('node_id') + @classmethod + def validate_node_id_with_scope(cls, v: str | None, info) -> str | None: + """Validate node_id consistency with scope""" + scope = info.data.get('scope') + if scope == MemoryScope.NODE and v is None: + raise ValueError("node_id is required when scope is NODE") + if scope == MemoryScope.APP and v is not None: + raise ValueError("node_id must be None when scope is APP") + return v class MemoryCreatedBy(BaseModel): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 1e77d12715..6510b15114 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -44,13 +44,7 @@ class MemoryConfig(BaseModel): enabled: bool size: int | None = None - mode: Literal["linear", "block"] | None = "linear" - block_id: list[str] | None = None role_prefix: RolePrefix | None = None window: WindowConfig query_prompt_template: str | None = None - - @property - def is_block_mode(self) -> bool: - return self.mode == "block" and bool(self.block_id) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 818c711e49..3de84e83cd 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1234,13 +1234,6 @@ class LLMNode(Node): tenant_id=self.tenant_id ) - memory_config = self._node_data.memory - if not memory_config: - return - block_ids = memory_config.block_id - if not block_ids: - return - # FIXME: This is dirty workaround and may cause incorrect resolution for workflow version with Session(db.engine) as session: stmt = select(Workflow).where( @@ -1250,24 +1243,31 @@ class LLMNode(Node): workflow = session.scalars(stmt).first() if not workflow: raise ValueError("Workflow not found.") - memory_blocks = workflow.memory_blocks - for block_id in block_ids: - memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None) + # Filter memory blocks that belong to this node + node_memory_blocks = [ + block for block in workflow.memory_blocks + if block.scope == MemoryScope.NODE and block.node_id == self.id + ] - if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE: - is_draft = (self.invoke_from == InvokeFrom.DEBUGGER) - from services.chatflow_memory_service import ChatflowMemoryService - ChatflowMemoryService.update_node_memory_if_needed( - tenant_id=self.tenant_id, - app_id=self.app_id, - node_id=self.id, - conversation_id=conversation_id, - memory_block_spec=memory_block_spec, - variable_pool=variable_pool, - is_draft=is_draft, - created_by=self._get_user_from_context() - ) + if not node_memory_blocks: + return + + # Update each memory block that belongs to this node + is_draft = (self.invoke_from == InvokeFrom.DEBUGGER) + from services.chatflow_memory_service import ChatflowMemoryService + + for memory_block_spec in node_memory_blocks: + ChatflowMemoryService.update_node_memory_if_needed( + tenant_id=self.tenant_id, + app_id=self.app_id, + node_id=self.id, + conversation_id=conversation_id, + memory_block_spec=memory_block_spec, + variable_pool=variable_pool, + is_draft=is_draft, + created_by=self._get_user_from_context() + ) def _get_user_from_context(self) -> MemoryCreatedBy: if self.user_from == UserFrom.ACCOUNT: diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 46ff19eabe..f981adfdc2 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -71,6 +71,7 @@ memory_block_fields = { "model": fields.Nested(model_config_fields), "end_user_visible": fields.Boolean, "end_user_editable": fields.Boolean, + "node_id": fields.String, } pipeline_variable_fields = {