From 95c5848d05b4e0a881eb61255bd0b902a954a77c Mon Sep 17 00:00:00 2001
From: takatost
Date: Thu, 21 Mar 2024 17:06:35 +0800
Subject: [PATCH] update workflow app bind datasets

---
Review note: the new event handler module and the event_handlers/__init__.py
hunk were revised during review (logging instead of silently swallowing node
parsing errors, simplified set arithmetic, sorted import). The blob ids on
the "index" lines of those two files still refer to the original revision.

 api/events/app_event.py                       |  5 +-
 api/events/event_handlers/__init__.py         |  1 +
 ...oin_when_app_published_workflow_updated.py | 79 +++++++++++++++++++
 api/services/workflow_service.py              |  4 +-
 4 files changed, 87 insertions(+), 2 deletions(-)
 create mode 100644 api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py

diff --git a/api/events/app_event.py b/api/events/app_event.py
index 938478d3b7..8dbf34cbd1 100644
--- a/api/events/app_event.py
+++ b/api/events/app_event.py
@@ -6,5 +6,8 @@ app_was_created = signal('app-was-created')
 # sender: app
 app_was_deleted = signal('app-was-deleted')
 
-# sender: app, kwargs: old_app_model_config, new_app_model_config
+# sender: app, kwargs: app_model_config
 app_model_config_was_updated = signal('app-model-config-was-updated')
+
+# sender: app, kwargs: published_workflow
+app_published_workflow_was_updated = signal('app-published-workflow-was-updated')
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index fdfb401bd4..e0f3b84990 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -8,3 +8,4 @@ from .delete_installed_app_when_app_deleted import handle
 from .generate_conversation_name_when_first_message_created import handle
 from .update_app_dataset_join_when_app_model_config_updated import handle
+from .update_app_dataset_join_when_app_published_workflow_updated import handle
 from .update_provider_last_used_at_when_messaeg_created import handle
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
new file mode 100644
index 0000000000..996b1e9691
--- /dev/null
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
@@ -0,0 +1,79 @@
+import logging
+from typing import cast
+
+from core.workflow.entities.node_entities import NodeType
+from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
+from events.app_event import app_published_workflow_was_updated
+from extensions.ext_database import db
+from models.dataset import AppDatasetJoin
+from models.workflow import Workflow
+
+logger = logging.getLogger(__name__)
+
+
+@app_published_workflow_was_updated.connect
+def handle(sender, **kwargs):
+    """Sync the app's dataset bindings with its newly published workflow.
+
+    Diffs the dataset ids referenced by the workflow's knowledge retrieval
+    nodes against the app's existing ``AppDatasetJoin`` rows, then deletes
+    stale rows and inserts missing ones in a single commit.
+
+    :param sender: the app whose published workflow changed
+    :param kwargs: must contain ``published_workflow``
+    """
+    app = sender
+    published_workflow = cast(Workflow, kwargs.get('published_workflow'))
+
+    dataset_ids = get_dataset_ids_from_workflow(published_workflow)
+    app_dataset_joins = db.session.query(AppDatasetJoin).filter(
+        AppDatasetJoin.app_id == app.id
+    ).all()
+
+    # Set differences also cover the "no existing joins" case, so no
+    # special branch is needed and both results are always sets.
+    old_dataset_ids = {join.dataset_id for join in app_dataset_joins}
+    added_dataset_ids = dataset_ids - old_dataset_ids
+    removed_dataset_ids = old_dataset_ids - dataset_ids
+
+    for dataset_id in removed_dataset_ids:
+        db.session.query(AppDatasetJoin).filter(
+            AppDatasetJoin.app_id == app.id,
+            AppDatasetJoin.dataset_id == dataset_id
+        ).delete()
+
+    for dataset_id in added_dataset_ids:
+        db.session.add(AppDatasetJoin(
+            app_id=app.id,
+            dataset_id=dataset_id
+        ))
+
+    db.session.commit()
+
+
+def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
+    """Collect the dataset ids used by a workflow's knowledge retrieval nodes.
+
+    :param published_workflow: workflow whose ``graph_dict`` is inspected
+    :return: set of dataset ids; empty when the graph is empty or has no
+        knowledge retrieval nodes
+    """
+    dataset_ids = set()
+    graph = published_workflow.graph_dict
+    if not graph:
+        return dataset_ids
+
+    for node in graph.get('nodes', []):
+        node_data = node.get('data', {})
+        if node_data.get('type') != NodeType.KNOWLEDGE_RETRIEVAL.value:
+            continue
+
+        try:
+            dataset_ids.update(KnowledgeRetrievalNodeData(**node_data).dataset_ids)
+        except Exception:
+            # Best effort: one malformed node must not block syncing the
+            # others, but leave a trace instead of dropping it silently.
+            logger.exception('Failed to parse knowledge retrieval node %s of workflow %s',
+                             node.get('id'), published_workflow.id)
+
+    return dataset_ids
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index a2cc7448e5..ecbe9721a9 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -9,6 +9,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.entities.node_entities import NodeType
 from core.workflow.errors import WorkflowNodeRunFailedError
 from core.workflow.workflow_engine_manager import WorkflowEngineManager
+from events.app_event import app_published_workflow_was_updated
 from extensions.ext_database import db
 from models.account import Account
 from models.model import App, AppMode
@@ -138,7 +139,8 @@ class WorkflowService:
         app_model.workflow_id = workflow.id
         db.session.commit()
 
-        # TODO update app related datasets
+        # trigger app workflow events
+        app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
 
         # return new workflow
         return workflow