fix: store memory_blocks in correct field

This commit is contained in:
Stream 2025-10-15 16:56:12 +08:00
parent 7ca06931ec
commit f4fa57dac9
No known key found for this signature in database
GPG Key ID: 033728094B100D70
3 changed files with 36 additions and 7 deletions

View File

@ -103,6 +103,7 @@ class DraftWorkflowApi(Resource):
"hash": fields.String(description="Workflow hash for validation"), "hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"), "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"), "conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
}, },
) )
) )
@ -127,6 +128,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json") parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("memory_blocks", type=list, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
elif "text/plain" in content_type: elif "text/plain" in content_type:
try: try:
@ -143,6 +145,7 @@ class DraftWorkflowApi(Resource):
"hash": data.get("hash"), "hash": data.get("hash"),
"environment_variables": data.get("environment_variables"), "environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"), "conversation_variables": data.get("conversation_variables"),
"memory_blocks": data.get("memory_blocks"),
} }
except json.JSONDecodeError: except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400 return {"message": "Invalid JSON data"}, 400
@ -163,6 +166,11 @@ class DraftWorkflowApi(Resource):
conversation_variables = [ conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
] ]
memory_blocks_list = args.get("memory_blocks") or []
from core.memory.entities import MemoryBlockSpec
memory_blocks = [
MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list
]
workflow = workflow_service.sync_draft_workflow( workflow = workflow_service.sync_draft_workflow(
app_model=app_model, app_model=app_model,
graph=args["graph"], graph=args["graph"],
@ -171,6 +179,7 @@ class DraftWorkflowApi(Resource):
account=current_user, account=current_user,
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
memory_blocks=memory_blocks,
) )
except WorkflowHashNotEqualError: except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync() raise DraftWorkflowNotSync()

View File

@ -156,6 +156,9 @@ class Workflow(Base):
_rag_pipeline_variables: Mapped[str] = mapped_column( _rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}" "rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
) )
_memory_blocks: Mapped[str] = mapped_column(
"memory_blocks", sa.Text, nullable=False, server_default="[]"
)
VERSION_DRAFT = "draft" VERSION_DRAFT = "draft"
@ -173,6 +176,7 @@ class Workflow(Base):
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict], rag_pipeline_variables: list[dict],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
marked_name: str = "", marked_name: str = "",
marked_comment: str = "", marked_comment: str = "",
) -> "Workflow": ) -> "Workflow":
@ -188,6 +192,7 @@ class Workflow(Base):
workflow.environment_variables = environment_variables or [] workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or [] workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or [] workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.memory_blocks = memory_blocks or []
workflow.marked_name = marked_name workflow.marked_name = marked_name
workflow.marked_comment = marked_comment workflow.marked_comment = marked_comment
workflow.created_at = naive_utc_now() workflow.created_at = naive_utc_now()
@ -447,6 +452,7 @@ class Workflow(Base):
"features": self.features_dict, "features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables], "environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks],
} }
return result return result
@ -486,15 +492,25 @@ class Workflow(Base):
@property @property
def memory_blocks(self) -> Sequence[MemoryBlockSpec]: def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
"""Memory blocks configuration from graph""" """Memory blocks configuration stored in database"""
if self._memory_blocks is None or self._memory_blocks == "":
self._memory_blocks = "[]"
if not self.graph_dict: memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks)
return [] results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list]
memory_blocks_config = self.graph_dict.get('memory_blocks', [])
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_config]
return results return results
@memory_blocks.setter
def memory_blocks(self, value: Sequence[MemoryBlockSpec]):
if not value:
self._memory_blocks = "[]"
return
self._memory_blocks = json.dumps(
[block.model_dump() for block in value],
ensure_ascii=False,
)
@staticmethod @staticmethod
def version_from_datetime(d: datetime) -> str: def version_from_datetime(d: datetime) -> str:
return str(d) return str(d)

View File

@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File from core.file import File
from core.memory.entities import MemoryCreatedBy, MemoryScope from core.memory.entities import MemoryCreatedBy, MemoryScope, MemoryBlockSpec
from core.repositories import DifyCoreRepositoryFactory from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable from core.variables import Variable
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
@ -197,6 +197,7 @@ class WorkflowService:
account: Account, account: Account,
environment_variables: Sequence[Variable], environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable], conversation_variables: Sequence[Variable],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
) -> Workflow: ) -> Workflow:
""" """
Sync draft workflow Sync draft workflow
@ -224,6 +225,7 @@ class WorkflowService:
environment_variables=environment_variables, environment_variables=environment_variables,
conversation_variables=conversation_variables, conversation_variables=conversation_variables,
) )
workflow.memory_blocks = memory_blocks or []
db.session.add(workflow) db.session.add(workflow)
# update draft workflow if found # update draft workflow if found
else: else:
@ -233,6 +235,7 @@ class WorkflowService:
workflow.updated_at = naive_utc_now() workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables workflow.conversation_variables = conversation_variables
workflow.memory_blocks = memory_blocks or []
# commit db session changes # commit db session changes
db.session.commit() db.session.commit()
@ -280,6 +283,7 @@ class WorkflowService:
marked_name=marked_name, marked_name=marked_name,
marked_comment=marked_comment, marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables, rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
memory_blocks=draft_workflow.memory_blocks,
features=draft_workflow.features, features=draft_workflow.features,
) )