From 68c10a1672f51c3f3ddb45f1e20615fc15a0cb99 Mon Sep 17 00:00:00 2001 From: Yeuoly Date: Tue, 24 Sep 2024 18:03:48 +0800 Subject: [PATCH] feat: add backwards invoke node api --- api/controllers/inner_api/plugin/plugin.py | 70 ++++++++---- api/core/plugin/backwards_invocation/node.py | 114 +++++++++++++++++++ api/core/plugin/entities/request.py | 30 ++++- api/core/workflow/workflow_entry.py | 82 +++++++++++++ api/services/workflow_service.py | 81 ++++++++++--- 5 files changed, 335 insertions(+), 42 deletions(-) create mode 100644 api/core/plugin/backwards_invocation/node.py diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 4c28e6acb3..ae35332689 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -8,13 +8,15 @@ from controllers.inner_api.plugin.wraps import get_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation from core.plugin.encrypt import PluginEncrypter from core.plugin.entities.request import ( RequestInvokeApp, RequestInvokeEncrypt, RequestInvokeLLM, RequestInvokeModeration, - RequestInvokeNode, + RequestInvokeParameterExtractorNode, + RequestInvokeQuestionClassifierNode, RequestInvokeRerank, RequestInvokeSpeech2Text, RequestInvokeTextEmbedding, @@ -96,23 +98,46 @@ class PluginInvokeToolApi(Resource): yield ( ToolInvokeMessage( type=ToolInvokeMessage.MessageType.TEXT, - message=ToolInvokeMessage.TextMessage(text='helloworld'), + message=ToolInvokeMessage.TextMessage(text="helloworld"), ) .model_dump_json() .encode() - + b'\n\n' + + b"\n\n" ) return compact_generate_response(generator()) -class PluginInvokeNodeApi(Resource): +class PluginInvokeParameterExtractorNodeApi(Resource): @setup_required @plugin_inner_api_only @get_tenant - @plugin_data(payload_type=RequestInvokeNode) - def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode): - pass + @plugin_data(payload_type=RequestInvokeParameterExtractorNode) + def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode): + return PluginNodeBackwardsInvocation.invoke_parameter_extractor( + tenant_id=tenant_model.id, + user_id=user_id, + parameters=payload.parameters, + model_config=payload.model, + instruction=payload.instruction, + query=payload.query, + ) + + +class PluginInvokeQuestionClassifierNodeApi(Resource): + @setup_required + @plugin_inner_api_only + @get_tenant + @plugin_data(payload_type=RequestInvokeQuestionClassifierNode) + def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode): + return PluginNodeBackwardsInvocation.invoke_question_classifier( + tenant_id=tenant_model.id, + user_id=user_id, + query=payload.query, + model_config=payload.model, + classes=payload.classes, + instruction=payload.instruction, + ) class PluginInvokeAppApi(Resource): @@ -127,15 +152,13 @@ class PluginInvokeAppApi(Resource): tenant_id=tenant_model.id, conversation_id=payload.conversation_id, query=payload.query, - stream=payload.response_mode == 'streaming', + stream=payload.response_mode == "streaming", inputs=payload.inputs, - files=payload.files - ) - - return compact_generate_response( - PluginAppBackwardsInvocation.convert_to_event_stream(response) + files=payload.files, ) + return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response)) + class PluginInvokeEncryptApi(Resource): @setup_required @@ -149,13 +172,14 @@ class PluginInvokeEncryptApi(Resource): return PluginEncrypter.invoke_encrypt(tenant_model, payload) -api.add_resource(PluginInvokeLLMApi, '/invoke/llm') -api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') -api.add_resource(PluginInvokeRerankApi, '/invoke/rerank') -api.add_resource(PluginInvokeTTSApi, '/invoke/tts') -api.add_resource(PluginInvokeSpeech2TextApi, '/invoke/speech2text') -api.add_resource(PluginInvokeModerationApi, '/invoke/moderation') -api.add_resource(PluginInvokeToolApi, '/invoke/tool') -api.add_resource(PluginInvokeNodeApi, '/invoke/node') -api.add_resource(PluginInvokeAppApi, '/invoke/app') -api.add_resource(PluginInvokeEncryptApi, '/invoke/encrypt') +api.add_resource(PluginInvokeLLMApi, "/invoke/llm") +api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding") +api.add_resource(PluginInvokeRerankApi, "/invoke/rerank") +api.add_resource(PluginInvokeTTSApi, "/invoke/tts") +api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text") +api.add_resource(PluginInvokeModerationApi, "/invoke/moderation") +api.add_resource(PluginInvokeToolApi, "/invoke/tool") +api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor") +api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier") +api.add_resource(PluginInvokeAppApi, "/invoke/app") +api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt") diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py new file mode 100644 index 0000000000..9a7fd5fc3a --- /dev/null +++ b/api/core/plugin/backwards_invocation/node.py @@ -0,0 +1,114 @@ +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation +from core.workflow.nodes.parameter_extractor.entities import ( + ModelConfig as ParameterExtractorModelConfig, +) +from core.workflow.nodes.parameter_extractor.entities import ( + ParameterConfig, + ParameterExtractorNodeData, +) +from core.workflow.nodes.question_classifier.entities import ( + ClassConfig, + QuestionClassifierNodeData, +) +from core.workflow.nodes.question_classifier.entities import ( + ModelConfig as QuestionClassifierModelConfig, +) +from services.workflow_service import WorkflowService + + +class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): + @classmethod + def invoke_parameter_extractor( + cls, + tenant_id: str, + user_id: str, + parameters: list[ParameterConfig], + model_config: ParameterExtractorModelConfig, + instruction: str, + query: str, + ) -> dict: + """ + Invoke parameter extractor node. + + :param tenant_id: str + :param user_id: str + :param parameters: list[ParameterConfig] + :param model_config: ModelConfig + :param instruction: str + :param query: str + :return: dict with __reason, __is_success, and other parameters + """ + workflow_service = WorkflowService() + node_id = "1919810" + node_data = ParameterExtractorNodeData( + title="parameter_extractor", + desc="parameter_extractor", + parameters=parameters, + reasoning_mode="function_call", + query=[node_id, "query"], + model=model_config, + instruction=instruction, # instruct with variables are not supported + ) + node_data_dict = node_data.model_dump() + execution = workflow_service.run_free_workflow_node( + node_data_dict, + tenant_id=tenant_id, + user_id=user_id, + node_id=node_id, + user_inputs={ + f"{node_id}.query": query, + }, + ) + + output = execution.outputs_dict + return output or { + "__reason": "No parameters extracted", + "__is_success": False, + } + + @classmethod + def invoke_question_classifier( + cls, + tenant_id: str, + user_id: str, + model_config: QuestionClassifierModelConfig, + classes: list[ClassConfig], + instruction: str, + query: str, + ) -> dict: + """ + Invoke question classifier node. + + :param tenant_id: str + :param user_id: str + :param model_config: ModelConfig + :param classes: list[ClassConfig] + :param instruction: str + :param query: str + :return: dict with class_name + """ + workflow_service = WorkflowService() + node_id = "1919810" + node_data = QuestionClassifierNodeData( + title="question_classifier", + desc="question_classifier", + query_variable_selector=[node_id, "query"], + model=model_config, + classes=classes, + instruction=instruction, # instruct with variables are not supported + ) + node_data_dict = node_data.model_dump() + execution = workflow_service.run_free_workflow_node( + node_data_dict, + tenant_id=tenant_id, + user_id=user_id, + node_id=node_id, + user_inputs={ + f"{node_id}.query": query, + }, + ) + + output = execution.outputs_dict + return output or { + "class_name": classes[0].name, + } diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 2e87b76636..bf4c4448c7 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -14,6 +14,16 @@ from core.model_runtime.entities.message_entities import ( UserPromptMessage, ) from core.model_runtime.entities.model_entities import ModelType +from core.workflow.nodes.question_classifier.entities import ( + ClassConfig, + ModelConfig as QuestionClassifierModelConfig, +) +from core.workflow.nodes.parameter_extractor.entities import ( + ModelConfig as ParameterExtractorModelConfig, +) +from core.workflow.nodes.parameter_extractor.entities import ( + ParameterConfig, +) class RequestInvokeTool(BaseModel): @@ -92,11 +102,27 @@ class RequestInvokeModeration(BaseModel): """ -class RequestInvokeNode(BaseModel): +class RequestInvokeParameterExtractorNode(BaseModel): """ - Request to invoke node + Request to invoke parameter extractor node """ + parameters: list[ParameterConfig] + model: ParameterExtractorModelConfig + instruction: str + query: str + + +class RequestInvokeQuestionClassifierNode(BaseModel): + """ + Request to invoke question classifier node + """ + + query: str + model: QuestionClassifierModelConfig + classes: list[ClassConfig] + instruction: str + class RequestInvokeApp(BaseModel): """ diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 74a598ada5..9477e98c92 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -205,6 +205,88 @@ class WorkflowEntry: except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + @classmethod + def run_free_node( + cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] + ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]: + """ + Run free node + + NOTE: only parameter_extractor/question_classifier are supported + + :param node_data: node data + :param user_id: user id + :param user_inputs: user inputs + :return: + """ + # generate a fake graph + node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data} + graph_dict = { + "nodes": [node_config], + } + + node_type = NodeType.value_of(node_data.get("type", "")) + if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: + raise ValueError(f"Node type {node_type} not supported") + + node_cls = node_classes.get(node_type) + if not node_cls: + raise ValueError(f"Node class not found for node type {node_type}") + + graph = Graph.init(graph_config=graph_dict) + + # init variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + environment_variables=[], + ) + + node_cls = cast(type[BaseNode], node_cls) + # init workflow run state + node_instance: BaseNode = node_cls( + id=str(uuid.uuid4()), + config=node_config, + graph_init_params=GraphInitParams( + tenant_id=tenant_id, + app_id="", + workflow_type=WorkflowType.WORKFLOW, + workflow_id="", + graph_config=graph_dict, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()), + ) + + try: + # variable selector to variable mapping + try: + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=graph_dict, config=node_config + ) + except NotImplementedError: + variable_mapping = {} + + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=tenant_id, + node_type=node_type, + node_data=node_instance.node_data, + ) + + # run node + generator = node_instance.run() + + return node_instance, generator + except Exception as e: + raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + @classmethod def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: """ diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0ff81f1f7e..399451cb8e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,8 +1,8 @@ import json import time -from collections.abc import Sequence +from collections.abc import Callable, Generator, Sequence from datetime import datetime, timezone -from typing import Optional +from typing import Any, Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager @@ -10,7 +10,9 @@ from core.app.segments import Variable from core.model_runtime.utils.encoders import jsonable_encoder from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.graph_engine.entities.event import InNodeEvent +from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.event import RunCompletedEvent, RunEvent from core.workflow.nodes.node_mapping import node_classes from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated @@ -216,13 +218,64 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - try: - node_instance, generator = WorkflowEntry.single_step_run( + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, - ) + ), + start_at=start_at, + tenant_id=app_model.tenant_id, + node_id=node_id, + ) + + db.session.add(workflow_node_execution) + db.session.commit() + + return workflow_node_execution + + def run_free_workflow_node( + self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] + ) -> WorkflowNodeExecution: + """ + Run draft workflow node + """ + # run draft workflow node + start_at = time.perf_counter() + + workflow_node_execution = self._handle_node_run_result( + getter=lambda: WorkflowEntry.run_free_node( + node_id=node_id, + node_data=node_data, + tenant_id=tenant_id, + user_id=user_id, + user_inputs=user_inputs, + ), + start_at=start_at, + tenant_id=tenant_id, + node_id=node_id + ) + + return workflow_node_execution + + def _handle_node_run_result( + self, + getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]], + start_at: float, + tenant_id: str, + node_id: str, + ): + """ + Handle node run result + + :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]] + :param start_at: float + :param tenant_id: str + :param node_id: str + """ + try: + node_instance, generator = getter() node_run_result: NodeRunResult | None = None for event in generator: @@ -245,9 +298,7 @@ class WorkflowService: error = e.error workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.tenant_id = app_model.tenant_id - workflow_node_execution.app_id = app_model.id - workflow_node_execution.workflow_id = draft_workflow.id + workflow_node_execution.tenant_id = tenant_id workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.index = 1 workflow_node_execution.node_id = node_id @@ -255,7 +306,6 @@ class WorkflowService: workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value - workflow_node_execution.created_by = account.id workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -277,9 +327,6 @@ class WorkflowService: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error - db.session.add(workflow_node_execution) - db.session.commit() - return workflow_node_execution def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: @@ -302,10 +349,10 @@ class WorkflowService: new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get("name"), - icon_type=args.get("icon_type"), - icon=args.get("icon"), - icon_background=args.get("icon_background"), + name=args.get("name", ""), + icon_type=args.get("icon_type", ""), + icon=args.get("icon", ""), + icon_background=args.get("icon_background", ""), ) return new_app