fix: store memory_blocks in correct field

This commit is contained in:
Stream 2025-10-15 16:56:12 +08:00
parent 7ca06931ec
commit f4fa57dac9
No known key found for this signature in database
GPG Key ID: 033728094B100D70
3 changed files with 36 additions and 7 deletions

View File

@ -103,6 +103,7 @@ class DraftWorkflowApi(Resource):
"hash": fields.String(description="Workflow hash for validation"),
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
},
)
)
@ -127,6 +128,7 @@ class DraftWorkflowApi(Resource):
parser.add_argument("hash", type=str, required=False, location="json")
parser.add_argument("environment_variables", type=list, required=True, location="json")
parser.add_argument("conversation_variables", type=list, required=False, location="json")
parser.add_argument("memory_blocks", type=list, required=False, location="json")
args = parser.parse_args()
elif "text/plain" in content_type:
try:
@ -143,6 +145,7 @@ class DraftWorkflowApi(Resource):
"hash": data.get("hash"),
"environment_variables": data.get("environment_variables"),
"conversation_variables": data.get("conversation_variables"),
"memory_blocks": data.get("memory_blocks"),
}
except json.JSONDecodeError:
return {"message": "Invalid JSON data"}, 400
@ -163,6 +166,11 @@ class DraftWorkflowApi(Resource):
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
memory_blocks_list = args.get("memory_blocks") or []
from core.memory.entities import MemoryBlockSpec
memory_blocks = [
MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list
]
workflow = workflow_service.sync_draft_workflow(
app_model=app_model,
graph=args["graph"],
@ -171,6 +179,7 @@ class DraftWorkflowApi(Resource):
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
memory_blocks=memory_blocks,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()

View File

@ -156,6 +156,9 @@ class Workflow(Base):
_rag_pipeline_variables: Mapped[str] = mapped_column(
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
)
_memory_blocks: Mapped[str] = mapped_column(
"memory_blocks", sa.Text, nullable=False, server_default="[]"
)
VERSION_DRAFT = "draft"
@ -173,6 +176,7 @@ class Workflow(Base):
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
rag_pipeline_variables: list[dict],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
marked_name: str = "",
marked_comment: str = "",
) -> "Workflow":
@ -188,6 +192,7 @@ class Workflow(Base):
workflow.environment_variables = environment_variables or []
workflow.conversation_variables = conversation_variables or []
workflow.rag_pipeline_variables = rag_pipeline_variables or []
workflow.memory_blocks = memory_blocks or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
workflow.created_at = naive_utc_now()
@ -447,6 +452,7 @@ class Workflow(Base):
"features": self.features_dict,
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
"memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks],
}
return result
@ -486,15 +492,25 @@ class Workflow(Base):
@property
def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
"""Memory blocks configuration from graph"""
"""Memory blocks configuration stored in database"""
if self._memory_blocks is None or self._memory_blocks == "":
self._memory_blocks = "[]"
if not self.graph_dict:
return []
memory_blocks_config = self.graph_dict.get('memory_blocks', [])
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_config]
memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks)
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list]
return results
@memory_blocks.setter
def memory_blocks(self, value: Sequence[MemoryBlockSpec]):
if not value:
self._memory_blocks = "[]"
return
self._memory_blocks = json.dumps(
[block.model_dump() for block in value],
ensure_ascii=False,
)
@staticmethod
def version_from_datetime(d: datetime) -> str:
return str(d)

View File

@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.memory.entities import MemoryCreatedBy, MemoryScope
from core.memory.entities import MemoryCreatedBy, MemoryScope, MemoryBlockSpec
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
@ -197,6 +197,7 @@ class WorkflowService:
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
) -> Workflow:
"""
Sync draft workflow
@ -224,6 +225,7 @@ class WorkflowService:
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
workflow.memory_blocks = memory_blocks or []
db.session.add(workflow)
# update draft workflow if found
else:
@ -233,6 +235,7 @@ class WorkflowService:
workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
workflow.memory_blocks = memory_blocks or []
# commit db session changes
db.session.commit()
@ -280,6 +283,7 @@ class WorkflowService:
marked_name=marked_name,
marked_comment=marked_comment,
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
memory_blocks=draft_workflow.memory_blocks,
features=draft_workflow.features,
)