mirror of https://github.com/langgenius/dify.git
refactor: add node_id to MemoryBlockSpec
This commit is contained in:
parent
89d53ecf50
commit
f0ff2e1f2c
|
|
@ -4,7 +4,7 @@ from enum import StrEnum
|
|||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
|
||||
|
|
@ -49,6 +49,21 @@ class MemoryBlockSpec(BaseModel):
|
|||
model: ModelConfig = Field(description="Model configuration for memory updates")
|
||||
end_user_visible: bool = Field(default=False, description="Whether memory is visible to end users")
|
||||
end_user_editable: bool = Field(default=False, description="Whether memory is editable by end users")
|
||||
node_id: str | None = Field(
|
||||
default=None,
|
||||
description="Node ID when scope is NODE. Must be None when scope is APP."
|
||||
)
|
||||
|
||||
@field_validator('node_id')
|
||||
@classmethod
|
||||
def validate_node_id_with_scope(cls, v: str | None, info) -> str | None:
|
||||
"""Validate node_id consistency with scope"""
|
||||
scope = info.data.get('scope')
|
||||
if scope == MemoryScope.NODE and v is None:
|
||||
raise ValueError("node_id is required when scope is NODE")
|
||||
if scope == MemoryScope.APP and v is not None:
|
||||
raise ValueError("node_id must be None when scope is APP")
|
||||
return v
|
||||
|
||||
|
||||
class MemoryCreatedBy(BaseModel):
|
||||
|
|
|
|||
|
|
@ -44,13 +44,7 @@ class MemoryConfig(BaseModel):
|
|||
|
||||
enabled: bool
|
||||
size: int | None = None
|
||||
|
||||
mode: Literal["linear", "block"] | None = "linear"
|
||||
block_id: list[str] | None = None
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: WindowConfig
|
||||
query_prompt_template: str | None = None
|
||||
|
||||
@property
|
||||
def is_block_mode(self) -> bool:
|
||||
return self.mode == "block" and bool(self.block_id)
|
||||
|
|
|
|||
|
|
@ -1234,13 +1234,6 @@ class LLMNode(Node):
|
|||
tenant_id=self.tenant_id
|
||||
)
|
||||
|
||||
memory_config = self._node_data.memory
|
||||
if not memory_config:
|
||||
return
|
||||
block_ids = memory_config.block_id
|
||||
if not block_ids:
|
||||
return
|
||||
|
||||
# FIXME: This is dirty workaround and may cause incorrect resolution for workflow version
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(Workflow).where(
|
||||
|
|
@ -1250,24 +1243,31 @@ class LLMNode(Node):
|
|||
workflow = session.scalars(stmt).first()
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found.")
|
||||
memory_blocks = workflow.memory_blocks
|
||||
|
||||
for block_id in block_ids:
|
||||
memory_block_spec = next((block for block in memory_blocks if block.id == block_id), None)
|
||||
# Filter memory blocks that belong to this node
|
||||
node_memory_blocks = [
|
||||
block for block in workflow.memory_blocks
|
||||
if block.scope == MemoryScope.NODE and block.node_id == self.id
|
||||
]
|
||||
|
||||
if memory_block_spec and memory_block_spec.scope == MemoryScope.NODE:
|
||||
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
ChatflowMemoryService.update_node_memory_if_needed(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
node_id=self.id,
|
||||
conversation_id=conversation_id,
|
||||
memory_block_spec=memory_block_spec,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_user_from_context()
|
||||
)
|
||||
if not node_memory_blocks:
|
||||
return
|
||||
|
||||
# Update each memory block that belongs to this node
|
||||
is_draft = (self.invoke_from == InvokeFrom.DEBUGGER)
|
||||
from services.chatflow_memory_service import ChatflowMemoryService
|
||||
|
||||
for memory_block_spec in node_memory_blocks:
|
||||
ChatflowMemoryService.update_node_memory_if_needed(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
node_id=self.id,
|
||||
conversation_id=conversation_id,
|
||||
memory_block_spec=memory_block_spec,
|
||||
variable_pool=variable_pool,
|
||||
is_draft=is_draft,
|
||||
created_by=self._get_user_from_context()
|
||||
)
|
||||
|
||||
def _get_user_from_context(self) -> MemoryCreatedBy:
|
||||
if self.user_from == UserFrom.ACCOUNT:
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ memory_block_fields = {
|
|||
"model": fields.Nested(model_config_fields),
|
||||
"end_user_visible": fields.Boolean,
|
||||
"end_user_editable": fields.Boolean,
|
||||
"node_id": fields.String,
|
||||
}
|
||||
|
||||
pipeline_variable_fields = {
|
||||
|
|
|
|||
Loading…
Reference in New Issue