diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 1f5cbbeca5..3670f6c920 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -103,6 +103,7 @@ class DraftWorkflowApi(Resource): "hash": fields.String(description="Workflow hash for validation"), "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"), "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), + "memory_blocks": fields.List(fields.Raw, description="Memory blocks"), }, ) ) @@ -127,6 +128,7 @@ class DraftWorkflowApi(Resource): parser.add_argument("hash", type=str, required=False, location="json") parser.add_argument("environment_variables", type=list, required=True, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") + parser.add_argument("memory_blocks", type=list, required=False, location="json") args = parser.parse_args() elif "text/plain" in content_type: try: @@ -143,6 +145,7 @@ class DraftWorkflowApi(Resource): "hash": data.get("hash"), "environment_variables": data.get("environment_variables"), "conversation_variables": data.get("conversation_variables"), + "memory_blocks": data.get("memory_blocks"), } except json.JSONDecodeError: return {"message": "Invalid JSON data"}, 400 @@ -163,6 +166,11 @@ class DraftWorkflowApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] + memory_blocks_list = args.get("memory_blocks") or [] + from core.memory.entities import MemoryBlockSpec + memory_blocks = [ + MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list + ] workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args["graph"], @@ -171,6 +179,7 @@ class DraftWorkflowApi(Resource): account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, + memory_blocks=memory_blocks, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/models/workflow.py b/api/models/workflow.py index 49f72b61dc..505bd606a5 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -156,6 +156,9 @@ class Workflow(Base): _rag_pipeline_variables: Mapped[str] = mapped_column( "rag_pipeline_variables", db.Text, nullable=False, server_default="{}" ) + _memory_blocks: Mapped[str] = mapped_column( + "memory_blocks", sa.Text, nullable=False, server_default="[]" + ) VERSION_DRAFT = "draft" @@ -173,6 +176,7 @@ class Workflow(Base): environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], rag_pipeline_variables: list[dict], + memory_blocks: Sequence[MemoryBlockSpec] | None = None, marked_name: str = "", marked_comment: str = "", ) -> "Workflow": @@ -188,6 +192,7 @@ class Workflow(Base): workflow.environment_variables = environment_variables or [] workflow.conversation_variables = conversation_variables or [] workflow.rag_pipeline_variables = rag_pipeline_variables or [] + workflow.memory_blocks = memory_blocks or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment workflow.created_at = naive_utc_now() @@ -447,6 +452,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], + "memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks], } return result @@ -486,15 +492,25 @@ class Workflow(Base): @property def memory_blocks(self) -> Sequence[MemoryBlockSpec]: - """Memory blocks configuration from graph""" + """Memory blocks configuration stored in database""" + if self._memory_blocks is None or self._memory_blocks == "": + self._memory_blocks = "[]" - if not self.graph_dict: - return [] - - memory_blocks_config = self.graph_dict.get('memory_blocks', []) - results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_config] + memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks) + results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list] return results + @memory_blocks.setter + def memory_blocks(self, value: Sequence[MemoryBlockSpec]): + if not value: + self._memory_blocks = "[]" + return + + self._memory_blocks = json.dumps( + [block.model_dump() for block in value], + ensure_ascii=False, + ) + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 5a2b6e8f56..19e85217e9 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.memory.entities import MemoryCreatedBy, MemoryScope +from core.memory.entities import MemoryCreatedBy, MemoryScope, MemoryBlockSpec from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion @@ -197,6 +197,7 @@ class WorkflowService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + memory_blocks: Sequence[MemoryBlockSpec] | None = None, ) -> Workflow: """ Sync draft workflow @@ -224,6 +225,7 @@ class WorkflowService: environment_variables=environment_variables, conversation_variables=conversation_variables, ) + workflow.memory_blocks = memory_blocks or [] db.session.add(workflow) # update draft workflow if found else: @@ -233,6 +235,7 @@ class WorkflowService: workflow.updated_at = naive_utc_now() workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables + workflow.memory_blocks = memory_blocks or [] # commit db session changes db.session.commit() @@ -280,6 +283,7 @@ class WorkflowService: marked_name=marked_name, marked_comment=marked_comment, rag_pipeline_variables=draft_workflow.rag_pipeline_variables, + memory_blocks=draft_workflow.memory_blocks, features=draft_workflow.features, )