refactor: add node_id to MemoryBlockSpec

This commit is contained in:
Stream 2025-10-28 13:04:01 +08:00
parent 89d53ecf50
commit f0ff2e1f2c
No known key found for this signature in database
GPG Key ID: 033728094B100D70
4 changed files with 40 additions and 30 deletions

View File

@ -4,7 +4,7 @@ from enum import StrEnum
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from core.app.app_config.entities import ModelConfig
@ -49,6 +49,21 @@ class MemoryBlockSpec(BaseModel):
model: ModelConfig = Field(description="Model configuration for memory updates")
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
node_id: str | None = Field(
default=None,
description="Node ID when scope is NODE. Must be None when scope is APP."
)
@field_validator('node_id')
@classmethod
def validate_node_id_with_scope(cls, v: str | None, info) -> str | None:
"""Validate node_id consistency with scope"""
scope = info.data.get('scope')
if scope == MemoryScope.NODE and v is None:
raise ValueError("node_id is required when scope is NODE")
if scope == MemoryScope.APP and v is not None:
raise ValueError("node_id must be None when scope is APP")
return v
class MemoryCreatedBy(BaseModel):

View File

@ -44,13 +44,7 @@ class MemoryConfig(BaseModel):
enabled: bool
size: int | None = None
mode: Literal["linear", "block"] | None = "linear"
block_id: list[str] | None = None
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
@property
def is_block_mode(self) -> bool:
return self.mode == "block" and bool(self.block_id)

View File

@ -1234,13 +1234,6 @@ class LLMNode(Node):
tenant_id=self.tenant_id
)
memory_config = self._node_data.memory
if not memory_config:
return
block_ids = memory_config.block_id
if not block_ids:
return
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
with Session(db.engine) as session:
stmt = select(Workflow).where(
@ -1250,24 +1243,31 @@ class LLMNode(Node):
workflow = session.scalars(stmt).first()
if not workflow:
raise ValueError("Workflow not found.")
memory_blocks = workflow.memory_blocks
for block_id in block_ids:
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
# Filter memory blocks that belong to this node
node_memory_blocks = [
block for block in workflow.memory_blocks
if block.scope == MemoryScope.NODE and block.node_id == self.id
]
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
from services.chatflow_memory_service import ChatflowMemoryService
ChatflowMemoryService.update_node_memory_if_needed(
tenant_id=self.tenant_id,
app_id=self.app_id,
node_id=self.id,
conversation_id=conversation_id,
memory_block_spec=memory_block_spec,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_user_from_context()
)
if not node_memory_blocks:
return
# Update each memory block that belongs to this node
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
from services.chatflow_memory_service import ChatflowMemoryService
for memory_block_spec in node_memory_blocks:
ChatflowMemoryService.update_node_memory_if_needed(
tenant_id=self.tenant_id,
app_id=self.app_id,
node_id=self.id,
conversation_id=conversation_id,
memory_block_spec=memory_block_spec,
variable_pool=variable_pool,
is_draft=is_draft,
created_by=self._get_user_from_context()
)
def _get_user_from_context(self) -> MemoryCreatedBy:
if self.user_from == UserFrom.ACCOUNT:

View File

@ -71,6 +71,7 @@ memory_block_fields = {
"model": fields.Nested(model_config_fields),
"end_user_visible": fields.Boolean,
"end_user_editable": fields.Boolean,
"node_id": fields.String,
}
pipeline_variable_fields = {