diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 33927da9a1..87139409e5 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -106,6 +106,7 @@ class DraftWorkflowApi(Resource): "hash": fields.String(description="Workflow hash for validation"), "environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"), "conversation_variables": fields.List(fields.Raw, description="Conversation variables"), + "memory_blocks": fields.List(fields.Raw, description="Memory blocks"), }, ) ) @@ -131,6 +132,7 @@ class DraftWorkflowApi(Resource): parser.add_argument("environment_variables", type=list, required=True, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") parser.add_argument("force_upload", type=bool, required=False, default=False, location="json") + parser.add_argument("memory_blocks", type=list, required=False, location="json") args = parser.parse_args() elif "text/plain" in content_type: try: @@ -147,6 +149,7 @@ class DraftWorkflowApi(Resource): "hash": data.get("hash"), "environment_variables": data.get("environment_variables"), "conversation_variables": data.get("conversation_variables"), + "memory_blocks": data.get("memory_blocks"), "force_upload": data.get("force_upload", False), } except json.JSONDecodeError: @@ -168,6 +171,11 @@ class DraftWorkflowApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] + memory_blocks_list = args.get("memory_blocks") or [] + from core.memory.entities import MemoryBlockSpec + memory_blocks = [ + MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list + ] workflow = workflow_service.sync_draft_workflow( app_model=app_model, graph=args["graph"], @@ -177,6 +185,7 @@ class DraftWorkflowApi(Resource): environment_variables=environment_variables, conversation_variables=conversation_variables, force_upload=args.get("force_upload", False), + memory_blocks=memory_blocks, ) except WorkflowHashNotEqualError: raise DraftWorkflowNotSync() diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 6f91a9d5ed..3a3d4ae899 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -13,8 +13,9 @@ from core.variables.variables import RAGPipelineVariableInput, VariableUnion, Ve from core.workflow.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, + MEMORY_BLOCK_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, MEMORY_BLOCK_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, ) from core.workflow.system_variable import SystemVariable from factories import variable_factory diff --git a/api/models/workflow.py b/api/models/workflow.py index 0499d32a4b..0146e848e2 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -156,6 +156,9 @@ class Workflow(Base): _rag_pipeline_variables: Mapped[str] = mapped_column( "rag_pipeline_variables", db.Text, nullable=False, server_default="{}" ) + _memory_blocks: Mapped[str] = mapped_column( + "memory_blocks", sa.Text, nullable=False, server_default="[]" + ) VERSION_DRAFT = "draft" @@ -173,6 +176,7 @@ class Workflow(Base): environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], rag_pipeline_variables: list[dict], + memory_blocks: Sequence[MemoryBlockSpec] | None = None, marked_name: str = "", marked_comment: str = "", ) -> "Workflow": @@ -188,6 +192,7 @@ class Workflow(Base): workflow.environment_variables = environment_variables or [] workflow.conversation_variables = conversation_variables or [] workflow.rag_pipeline_variables = rag_pipeline_variables or [] + workflow.memory_blocks = memory_blocks or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment workflow.created_at = naive_utc_now() @@ -447,6 +452,7 @@ class Workflow(Base): "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], "conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables], + "memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks], } return result @@ -486,15 +492,25 @@ class Workflow(Base): @property def memory_blocks(self) -> Sequence[MemoryBlockSpec]: - """Memory blocks configuration from graph""" + """Memory blocks configuration stored in database""" + if self._memory_blocks is None or self._memory_blocks == "": + self._memory_blocks = "[]" - if not self.graph_dict: - return [] - - memory_blocks_config = self.graph_dict.get('memory_blocks', []) - results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_config] + memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks) + results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list] return results + @memory_blocks.setter + def memory_blocks(self, value: Sequence[MemoryBlockSpec]): + if not value: + self._memory_blocks = "[]" + return + + self._memory_blocks = json.dumps( + [block.model_dump() for block in value], + ensure_ascii=False, + ) + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 7b3adc97a8..f8eafd7b87 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.memory.entities import MemoryCreatedBy, MemoryScope +from core.memory.entities import MemoryBlockSpec, MemoryCreatedBy, MemoryScope from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.variables.variables import VariableUnion @@ -197,6 +197,7 @@ class WorkflowService: account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], + memory_blocks: Sequence[MemoryBlockSpec] | None = None, force_upload: bool = False, ) -> Workflow: """ @@ -226,6 +227,7 @@ class WorkflowService: environment_variables=environment_variables, conversation_variables=conversation_variables, ) + workflow.memory_blocks = memory_blocks or [] db.session.add(workflow) # update draft workflow if found else: @@ -235,6 +237,7 @@ class WorkflowService: workflow.updated_at = naive_utc_now() workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables + workflow.memory_blocks = memory_blocks or [] # commit db session changes db.session.commit() @@ -354,6 +357,7 @@ class WorkflowService: marked_name=marked_name, marked_comment=marked_comment, rag_pipeline_variables=draft_workflow.rag_pipeline_variables, + memory_blocks=draft_workflow.memory_blocks, features=draft_workflow.features, )