update workflow app bind datasets

This commit is contained in:
takatost 2024-03-21 17:06:35 +08:00
parent fa673f9b4c
commit 95c5848d05
4 changed files with 81 additions and 2 deletions

View File

@ -6,5 +6,8 @@ app_was_created = signal('app-was-created')
# sender: app
app_was_deleted = signal('app-was-deleted')
# sender: app, kwargs: old_app_model_config, new_app_model_config
# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal('app-model-config-was-updated')
# sender: app, kwargs: published_workflow
app_published_workflow_was_updated = signal('app-published-workflow-was-updated')

View File

@ -8,3 +8,4 @@ from .delete_installed_app_when_app_deleted import handle
from .generate_conversation_name_when_first_message_created import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_provider_last_used_at_when_messaeg_created import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle

View File

@ -0,0 +1,73 @@
from typing import cast
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.workflow import Workflow
@app_published_workflow_was_updated.connect
def handle(sender, **kwargs):
app = sender
published_workflow = kwargs.get('published_workflow')
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
app_dataset_joins = db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id
).all()
removed_dataset_ids = []
if not app_dataset_joins:
added_dataset_ids = dataset_ids
else:
old_dataset_ids = set()
for app_dataset_join in app_dataset_joins:
old_dataset_ids.add(app_dataset_join.dataset_id)
added_dataset_ids = dataset_ids - old_dataset_ids
removed_dataset_ids = old_dataset_ids - dataset_ids
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
db.session.query(AppDatasetJoin).filter(
AppDatasetJoin.app_id == app.id,
AppDatasetJoin.dataset_id == dataset_id
).delete()
if added_dataset_ids:
for dataset_id in added_dataset_ids:
app_dataset_join = AppDatasetJoin(
app_id=app.id,
dataset_id=dataset_id
)
db.session.add(app_dataset_join)
db.session.commit()
def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set:
dataset_ids = set()
graph = published_workflow.graph_dict
if not graph:
return dataset_ids
nodes = graph.get('nodes', [])
# fetch all knowledge retrieval nodes
knowledge_retrieval_nodes = [node for node in nodes
if node.get('data', {}).get('type') == NodeType.KNOWLEDGE_RETRIEVAL.value]
if not knowledge_retrieval_nodes:
return dataset_ids
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get('data', {}))
dataset_ids.update(node_data.dataset_ids)
except Exception as e:
continue
return dataset_ids

View File

@ -9,6 +9,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeType
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.workflow_engine_manager import WorkflowEngineManager
from events.app_event import app_published_workflow_was_updated
from extensions.ext_database import db
from models.account import Account
from models.model import App, AppMode
@ -138,7 +139,8 @@ class WorkflowService:
app_model.workflow_id = workflow.id
db.session.commit()
# TODO update app related datasets
# trigger app workflow events
app_published_workflow_was_updated.send(app_model, published_workflow=workflow)
# return new workflow
return workflow