mirror of https://github.com/langgenius/dify.git
fix: store memory_blocks in correct field
This commit is contained in:
parent
7ca06931ec
commit
f4fa57dac9
|
|
@ -103,6 +103,7 @@ class DraftWorkflowApi(Resource):
|
|||
"hash": fields.String(description="Workflow hash for validation"),
|
||||
"environment_variables": fields.List(fields.Raw, required=True, description="Environment variables"),
|
||||
"conversation_variables": fields.List(fields.Raw, description="Conversation variables"),
|
||||
"memory_blocks": fields.List(fields.Raw, description="Memory blocks"),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -127,6 +128,7 @@ class DraftWorkflowApi(Resource):
|
|||
parser.add_argument("hash", type=str, required=False, location="json")
|
||||
parser.add_argument("environment_variables", type=list, required=True, location="json")
|
||||
parser.add_argument("conversation_variables", type=list, required=False, location="json")
|
||||
parser.add_argument("memory_blocks", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
elif "text/plain" in content_type:
|
||||
try:
|
||||
|
|
@ -143,6 +145,7 @@ class DraftWorkflowApi(Resource):
|
|||
"hash": data.get("hash"),
|
||||
"environment_variables": data.get("environment_variables"),
|
||||
"conversation_variables": data.get("conversation_variables"),
|
||||
"memory_blocks": data.get("memory_blocks"),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
return {"message": "Invalid JSON data"}, 400
|
||||
|
|
@ -163,6 +166,11 @@ class DraftWorkflowApi(Resource):
|
|||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
memory_blocks_list = args.get("memory_blocks") or []
|
||||
from core.memory.entities import MemoryBlockSpec
|
||||
memory_blocks = [
|
||||
MemoryBlockSpec.model_validate(obj) for obj in memory_blocks_list
|
||||
]
|
||||
workflow = workflow_service.sync_draft_workflow(
|
||||
app_model=app_model,
|
||||
graph=args["graph"],
|
||||
|
|
@ -171,6 +179,7 @@ class DraftWorkflowApi(Resource):
|
|||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
memory_blocks=memory_blocks,
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
|
|
|
|||
|
|
@ -156,6 +156,9 @@ class Workflow(Base):
|
|||
_rag_pipeline_variables: Mapped[str] = mapped_column(
|
||||
"rag_pipeline_variables", db.Text, nullable=False, server_default="{}"
|
||||
)
|
||||
_memory_blocks: Mapped[str] = mapped_column(
|
||||
"memory_blocks", sa.Text, nullable=False, server_default="[]"
|
||||
)
|
||||
|
||||
VERSION_DRAFT = "draft"
|
||||
|
||||
|
|
@ -173,6 +176,7 @@ class Workflow(Base):
|
|||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
rag_pipeline_variables: list[dict],
|
||||
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> "Workflow":
|
||||
|
|
@ -188,6 +192,7 @@ class Workflow(Base):
|
|||
workflow.environment_variables = environment_variables or []
|
||||
workflow.conversation_variables = conversation_variables or []
|
||||
workflow.rag_pipeline_variables = rag_pipeline_variables or []
|
||||
workflow.memory_blocks = memory_blocks or []
|
||||
workflow.marked_name = marked_name
|
||||
workflow.marked_comment = marked_comment
|
||||
workflow.created_at = naive_utc_now()
|
||||
|
|
@ -447,6 +452,7 @@ class Workflow(Base):
|
|||
"features": self.features_dict,
|
||||
"environment_variables": [var.model_dump(mode="json") for var in environment_variables],
|
||||
"conversation_variables": [var.model_dump(mode="json") for var in self.conversation_variables],
|
||||
"memory_blocks": [block.model_dump(mode="json") for block in self.memory_blocks],
|
||||
}
|
||||
return result
|
||||
|
||||
|
|
@ -486,15 +492,25 @@ class Workflow(Base):
|
|||
|
||||
@property
|
||||
def memory_blocks(self) -> Sequence[MemoryBlockSpec]:
|
||||
"""Memory blocks configuration from graph"""
|
||||
"""Memory blocks configuration stored in database"""
|
||||
if self._memory_blocks is None or self._memory_blocks == "":
|
||||
self._memory_blocks = "[]"
|
||||
|
||||
if not self.graph_dict:
|
||||
return []
|
||||
|
||||
memory_blocks_config = self.graph_dict.get('memory_blocks', [])
|
||||
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_config]
|
||||
memory_blocks_list: list[dict[str, Any]] = json.loads(self._memory_blocks)
|
||||
results = [MemoryBlockSpec.model_validate(config) for config in memory_blocks_list]
|
||||
return results
|
||||
|
||||
@memory_blocks.setter
|
||||
def memory_blocks(self, value: Sequence[MemoryBlockSpec]):
|
||||
if not value:
|
||||
self._memory_blocks = "[]"
|
||||
return
|
||||
|
||||
self._memory_blocks = json.dumps(
|
||||
[block.model_dump() for block in value],
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def version_from_datetime(d: datetime) -> str:
|
||||
return str(d)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from core.app.app_config.entities import VariableEntityType
|
|||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.memory.entities import MemoryCreatedBy, MemoryScope
|
||||
from core.memory.entities import MemoryCreatedBy, MemoryScope, MemoryBlockSpec
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.variables import VariableUnion
|
||||
|
|
@ -197,6 +197,7 @@ class WorkflowService:
|
|||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
memory_blocks: Sequence[MemoryBlockSpec] | None = None,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
|
|
@ -224,6 +225,7 @@ class WorkflowService:
|
|||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
workflow.memory_blocks = memory_blocks or []
|
||||
db.session.add(workflow)
|
||||
# update draft workflow if found
|
||||
else:
|
||||
|
|
@ -233,6 +235,7 @@ class WorkflowService:
|
|||
workflow.updated_at = naive_utc_now()
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
workflow.memory_blocks = memory_blocks or []
|
||||
|
||||
# commit db session changes
|
||||
db.session.commit()
|
||||
|
|
@ -280,6 +283,7 @@ class WorkflowService:
|
|||
marked_name=marked_name,
|
||||
marked_comment=marked_comment,
|
||||
rag_pipeline_variables=draft_workflow.rag_pipeline_variables,
|
||||
memory_blocks=draft_workflow.memory_blocks,
|
||||
features=draft_workflow.features,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue