From 072967a1d38139b93f88418debe1f4d372cb6fa1 Mon Sep 17 00:00:00 2001 From: takatost Date: Mon, 1 Apr 2024 15:24:35 +0800 Subject: [PATCH] fix node single step run of answer & http request & llm --- api/core/workflow/nodes/answer/answer_node.py | 12 +++++++++- .../nodes/http_request/http_executor.py | 22 ++++++++++++++----- .../nodes/http_request/http_request_node.py | 18 ++++++++++++++- api/core/workflow/nodes/llm/llm_node.py | 14 ++++++++++++ 4 files changed, 58 insertions(+), 8 deletions(-) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index c11846a935..e8f1678ecb 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -142,4 +142,14 @@ class AnswerNode(BaseNode): :param node_data: node data :return: """ - return {} + node_data = node_data + node_data = cast(cls._node_data_cls, node_data) + + variable_template_parser = VariableTemplateParser(template=node_data.answer) + variable_selectors = variable_template_parser.extract_variable_selectors() + + variable_mapping = {} + for variable_selector in variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + return variable_mapping diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 17238da11a..d270277cd5 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -8,6 +8,7 @@ import httpx import requests import core.helper.ssrf_proxy as ssrf_proxy +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.http_request.entities import HttpRequestNodeData from core.workflow.utils.variable_template_parser import VariableTemplateParser @@ -135,6 +136,7 @@ class HttpExecutor: body: Union[None, str] files: Union[None, dict[str, Any]] boundary: str + variable_selectors: list[VariableSelector] def __init__(self, node_data: HttpRequestNodeData, variable_pool: VariablePool): """ @@ -149,6 +151,7 @@ class HttpExecutor: self.files = None # init template + self.variable_selectors = [] self._init_template(node_data, variable_pool) def _is_json_body(self, body: HttpRequestNodeData.Body): @@ -168,11 +171,13 @@ class HttpExecutor: """ init template """ + variable_selectors = [] + # extract all template in url - self.server_url = self._format_template(node_data.url, variable_pool) + self.server_url, server_url_variable_selectors = self._format_template(node_data.url, variable_pool) # extract all template in params - params = self._format_template(node_data.params, variable_pool) + params, params_variable_selectors = self._format_template(node_data.params, variable_pool) # fill in params kv_paris = params.split('\n') @@ -191,7 +196,7 @@ class HttpExecutor: self.params[k.strip()] = v # extract all template in headers - headers = self._format_template(node_data.headers, variable_pool) + headers, headers_variable_selectors = self._format_template(node_data.headers, variable_pool) # fill in headers kv_paris = headers.split('\n') @@ -210,13 +215,14 @@ class HttpExecutor: self.headers[k.strip()] = v.strip() # extract all template in body + body_data_variable_selectors = [] if node_data.body: # check if it's a valid JSON is_valid_json = self._is_json_body(node_data.body) body_data = node_data.body.data or '' if body_data: - body_data = self._format_template(body_data, variable_pool, is_valid_json) + body_data, body_data_variable_selectors = self._format_template(body_data, variable_pool, is_valid_json) if node_data.body.type == 'json': self.headers['Content-Type'] = 'application/json' @@ -251,6 +257,9 @@ class HttpExecutor: self.body = body_data elif node_data.body.type == 'none': self.body = '' + + self.variable_selectors = (server_url_variable_selectors + params_variable_selectors + + headers_variable_selectors + body_data_variable_selectors) def _assembling_headers(self) -> dict[str, Any]: authorization = deepcopy(self.authorization) @@ -364,7 +373,8 @@ class HttpExecutor: return raw_request - def _format_template(self, template: str, variable_pool: VariablePool, escape_quotes: bool = False) -> str: + def _format_template(self, template: str, variable_pool: VariablePool, escape_quotes: bool = False) \ + -> tuple[str, list[VariableSelector]]: """ format template """ @@ -386,4 +396,4 @@ class HttpExecutor: variable_value_mapping[variable_selector.variable] = value - return variable_template_parser.format(variable_value_mapping) + return variable_template_parser.format(variable_value_mapping), variable_selectors diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index 1e538bc2fe..47853832f8 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -1,3 +1,4 @@ +import logging from mimetypes import guess_extension from os import path from typing import cast @@ -60,7 +61,22 @@ class HttpRequestNode(BaseNode): :param node_data: node data :return: """ - return {} + try: + http_executor = HttpExecutor(node_data=node_data, variable_pool=VariablePool( + system_variables={}, + user_inputs={} + )) + + variable_selectors = http_executor.variable_selectors + + variable_mapping = {} + for variable_selector in variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + + return variable_mapping + except Exception as e: + logging.exception(f"Failed to extract variable selector to variable mapping: {e}") + return {} def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: """ diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index c0049c5bb3..df7b33d4c3 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -495,7 +495,21 @@ class LLMNode(BaseNode): node_data = node_data node_data = cast(cls._node_data_cls, node_data) + prompt_template = node_data.prompt_template + + variable_selectors = [] + if isinstance(prompt_template, list): + for prompt in prompt_template: + variable_template_parser = VariableTemplateParser(template=prompt.text) + variable_selectors.extend(variable_template_parser.extract_variable_selectors()) + else: + variable_template_parser = VariableTemplateParser(template=prompt_template.text) + variable_selectors = variable_template_parser.extract_variable_selectors() + variable_mapping = {} + for variable_selector in variable_selectors: + variable_mapping[variable_selector.variable] = variable_selector.value_selector + if node_data.context.enabled: variable_mapping['#context#'] = node_data.context.variable_selector