From de3b7e88154382ae9f6022d11b20027364b473b0 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 29 Mar 2024 20:54:08 +0800 Subject: [PATCH] http request node support template variable --- .../sensitive_word_avoidance/manager.py | 4 +- api/core/tools/tool_file_manager.py | 40 ++++---- api/core/utils/module_import_helper.py | 2 +- api/core/workflow/nodes/answer/answer_node.py | 11 +-- .../workflow/nodes/http_request/entities.py | 2 - .../nodes/http_request/http_executor.py | 95 +++++++++---------- .../nodes/http_request/http_request_node.py | 29 +++--- api/services/workflow/workflow_converter.py | 15 +-- .../workflow/nodes/test_http.py | 60 ++---------- 9 files changed, 90 insertions(+), 168 deletions(-) diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index 3dccfa3cbe..66d4a3275b 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -39,12 +39,12 @@ class SensitiveWordAvoidanceConfigManager: if not only_structure_validate: typ = config["sensitive_word_avoidance"]["type"] - config = config["sensitive_word_avoidance"]["config"] + sensitive_word_avoidance_config = config["sensitive_word_avoidance"]["config"] ModerationFactory.validate_config( name=typ, tenant_id=tenant_id, - config=config + config=sensitive_word_avoidance_config ) return config, ["sensitive_word_avoidance"] diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ceda31952e..e21a2efedd 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -6,7 +6,7 @@ import os import time from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Union +from typing import Optional, Union from uuid import uuid4 from flask import current_app @@ -19,6 +19,7 @@ from models.tools import ToolFile logger = logging.getLogger(__name__) + class ToolFileManager: @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: @@ -55,10 +56,10 @@ class ToolFileManager: return current_time - int(timestamp) <= 300 # expired after 5 minutes @staticmethod - def create_file_by_raw(user_id: str, tenant_id: str, - conversation_id: str, file_binary: bytes, - mimetype: str - ) -> ToolFile: + def create_file_by_raw(user_id: str, tenant_id: str, + conversation_id: Optional[str], file_binary: bytes, + mimetype: str + ) -> ToolFile: """ create file """ @@ -74,11 +75,11 @@ class ToolFileManager: db.session.commit() return tool_file - + @staticmethod - def create_file_by_url(user_id: str, tenant_id: str, - conversation_id: str, file_url: str, - ) -> ToolFile: + def create_file_by_url(user_id: str, tenant_id: str, + conversation_id: str, file_url: str, + ) -> ToolFile: """ create file """ @@ -93,26 +94,26 @@ class ToolFileManager: storage.save(filename, blob) tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=filename, + conversation_id=conversation_id, file_key=filename, mimetype=mimetype, original_url=file_url) - + db.session.add(tool_file) db.session.commit() return tool_file @staticmethod - def create_file_by_key(user_id: str, tenant_id: str, - conversation_id: str, file_key: str, - mimetype: str - ) -> ToolFile: + def create_file_by_key(user_id: str, tenant_id: str, + conversation_id: str, file_key: str, + mimetype: str + ) -> ToolFile: """ create file """ tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_key=file_key, mimetype=mimetype) return tool_file - + @staticmethod def get_file_binary(id: str) -> Union[tuple[bytes, str], None]: """ @@ -132,7 +133,7 @@ class ToolFileManager: blob = storage.load_once(tool_file.file_key) return blob, tool_file.mimetype - + @staticmethod def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]: """ @@ -161,7 +162,7 @@ class ToolFileManager: blob = storage.load_once(tool_file.file_key) return blob, tool_file.mimetype - + @staticmethod def get_file_generator_by_tool_file_id(tool_file_id: str) -> Union[tuple[Generator, str], None]: """ @@ -181,7 +182,8 @@ class ToolFileManager: generator = storage.load_stream(tool_file.file_key) return generator, tool_file.mimetype - + + # init tool_file_parser from core.file.tool_file_parser import tool_file_manager diff --git a/api/core/utils/module_import_helper.py b/api/core/utils/module_import_helper.py index 9e6e02f29f..d3a4bab4a1 100644 --- a/api/core/utils/module_import_helper.py +++ b/api/core/utils/module_import_helper.py @@ -59,4 +59,4 @@ def load_single_subclass_from_source( case 0: raise Exception(f'Missing subclass of {parent_type.__name__} in {script_path}') case _: - raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') + raise Exception(f'Multiple subclasses of {parent_type.__name__} in {script_path}') \ No newline at end of file diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 9194d3fef7..c11846a935 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -142,13 +142,4 @@ class AnswerNode(BaseNode): :param node_data: node data :return: """ - node_data = cast(cls._node_data_cls, node_data) - - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - return variable_mapping + return {} diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 4ab5538cf5..94ba6cb866 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -3,7 +3,6 @@ from typing import Literal, Optional, Union from pydantic import BaseModel, validator from core.workflow.entities.base_node_data_entities import BaseNodeData -from core.workflow.entities.variable_entities import VariableSelector class HttpRequestNodeData(BaseNodeData): @@ -36,7 +35,6 @@ class HttpRequestNodeData(BaseNodeData): type: Literal['none', 'form-data', 'x-www-form-urlencoded', 'raw-text', 'json'] data: Union[None, str] - variables: list[VariableSelector] method: Literal['get', 'post', 'put', 'patch', 'delete', 'head'] url: str authorization: Authorization diff --git a/api/core/workflow/nodes/http_request/http_executor.py b/api/core/workflow/nodes/http_request/http_executor.py index 67aa53a07b..cb636df2f3 100644 --- a/api/core/workflow/nodes/http_request/http_executor.py +++ b/api/core/workflow/nodes/http_request/http_executor.py @@ -1,5 +1,4 @@ import json -import re from copy import deepcopy from random import randint from typing import Any, Union @@ -9,7 +8,9 @@ import httpx import requests import core.helper.ssrf_proxy as ssrf_proxy +from core.workflow.entities.variable_pool import ValueType, VariablePool from core.workflow.nodes.http_request.entities import HttpRequestNodeData +from core.workflow.utils.variable_template_parser import VariableTemplateParser HTTP_REQUEST_DEFAULT_TIMEOUT = (10, 60) MAX_BINARY_SIZE = 1024 * 1024 * 10 # 10MB @@ -17,6 +18,7 @@ READABLE_MAX_BINARY_SIZE = '10MB' MAX_TEXT_SIZE = 1024 * 1024 // 10 # 0.1MB READABLE_MAX_TEXT_SIZE = '0.1MB' + class HttpExecutorResponse: headers: dict[str, str] response: Union[httpx.Response, requests.Response] @@ -123,6 +125,7 @@ class HttpExecutorResponse: else: return f'{(self.size / 1024 / 1024):.2f} MB' + class HttpExecutor: server_url: str method: str @@ -133,7 +136,7 @@ class HttpExecutor: files: Union[None, dict[str, Any]] boundary: str - def __init__(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + def __init__(self, node_data: HttpRequestNodeData, variable_pool: VariablePool): """ init """ @@ -146,49 +149,33 @@ class HttpExecutor: self.files = None # init template - self._init_template(node_data, variables) + self._init_template(node_data, variable_pool) - def _is_json_body(self, node_data: HttpRequestNodeData): + def _is_json_body(self, body: HttpRequestNodeData.Body): """ check if body is json """ - if node_data.body and node_data.body.type == 'json': + if body and body.type == 'json': try: - json.loads(node_data.body.data) + json.loads(body.data) return True except: return False return False - def _init_template(self, node_data: HttpRequestNodeData, variables: dict[str, Any]): + def _init_template(self, node_data: HttpRequestNodeData, variable_pool: VariablePool): """ init template """ # extract all template in url - url_template = re.findall(r'{{(.*?)}}', node_data.url) or [] - url_template = list(set(url_template)) - original_url = node_data.url - for url in url_template: - if not url: - continue - - original_url = original_url.replace(f'{{{{{url}}}}}', str(variables.get(url, ''))) - - self.server_url = original_url + self.server_url = self._format_template(node_data.url, variable_pool) # extract all template in params - param_template = re.findall(r'{{(.*?)}}', node_data.params) or [] - param_template = list(set(param_template)) - original_params = node_data.params - for param in param_template: - if not param: - continue - - original_params = original_params.replace(f'{{{{{param}}}}}', str(variables.get(param, ''))) + params = self._format_template(node_data.params, variable_pool) # fill in params - kv_paris = original_params.split('\n') + kv_paris = params.split('\n') for kv in kv_paris: if not kv.strip(): continue @@ -204,17 +191,10 @@ class HttpExecutor: self.params[k.strip()] = v # extract all template in headers - header_template = re.findall(r'{{(.*?)}}', node_data.headers) or [] - header_template = list(set(header_template)) - original_headers = node_data.headers - for header in header_template: - if not header: - continue - - original_headers = original_headers.replace(f'{{{{{header}}}}}', str(variables.get(header, ''))) + headers = self._format_template(node_data.headers, variable_pool) # fill in headers - kv_paris = original_headers.split('\n') + kv_paris = headers.split('\n') for kv in kv_paris: if not kv.strip(): continue @@ -232,19 +212,11 @@ class HttpExecutor: # extract all template in body if node_data.body: # check if it's a valid JSON - is_valid_json = self._is_json_body(node_data) - body_template = re.findall(r'{{(.*?)}}', node_data.body.data or '') or [] - body_template = list(set(body_template)) - original_body = node_data.body.data or '' - for body in body_template: - if not body: - continue - - body_value = variables.get(body, '') - if is_valid_json: - body_value = body_value.replace('"', '\\"') - - original_body = original_body.replace(f'{{{{{body}}}}}', body_value) + is_valid_json = self._is_json_body(node_data.body) + + body_data = node_data.body.data or '' + if body_data: + body_data = self._format_template(body_data, variable_pool, is_valid_json) if node_data.body.type == 'json': self.headers['Content-Type'] = 'application/json' @@ -253,7 +225,7 @@ class HttpExecutor: if node_data.body.type in ['form-data', 'x-www-form-urlencoded']: body = {} - kv_paris = original_body.split('\n') + kv_paris = body_data.split('\n') for kv in kv_paris: if not kv.strip(): continue @@ -276,7 +248,7 @@ class HttpExecutor: else: self.body = urlencode(body) elif node_data.body.type in ['json', 'raw-text']: - self.body = original_body + self.body = body_data elif node_data.body.type == 'none': self.body = '' @@ -390,4 +362,25 @@ class HttpExecutor: else: raw_request += self.body or '' - return raw_request \ No newline at end of file + return raw_request + + def _format_template(self, template: str, variable_pool: VariablePool, escape_quotes: bool = False) -> str: + """ + format template + """ + variable_template_parser = VariableTemplateParser(template=template) + variable_selectors = variable_template_parser.extract_variable_selectors() + + variable_value_mapping = {} + for variable_selector in variable_selectors: + value = variable_pool.get_variable_value( + variable_selector=variable_selector.value_selector, + target_value_type=ValueType.STRING + ) + + if escape_quotes: + value = value.replace('"', '\\"') + + variable_value_mapping[variable_selector.variable] = value + + return variable_template_parser.format(variable_value_mapping) diff --git a/api/core/workflow/nodes/http_request/http_request_node.py b/api/core/workflow/nodes/http_request/http_request_node.py index e74cdf3145..1e538bc2fe 100644 --- a/api/core/workflow/nodes/http_request/http_request_node.py +++ b/api/core/workflow/nodes/http_request/http_request_node.py @@ -19,33 +19,29 @@ class HttpRequestNode(BaseNode): def _run(self, variable_pool: VariablePool) -> NodeRunResult: node_data: HttpRequestNodeData = cast(self._node_data_cls, self.node_data) - # extract variables - variables = { - variable_selector.variable: variable_pool.get_variable_value(variable_selector=variable_selector.value_selector) - for variable_selector in node_data.variables - } - # init http executor + http_executor = None try: - http_executor = HttpExecutor(node_data=node_data, variables=variables) + http_executor = HttpExecutor(node_data=node_data, variable_pool=variable_pool) # invoke http executor response = http_executor.invoke() except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e), - process_data={ + process_data = {} + if http_executor: + process_data = { 'request': http_executor.to_raw_request(), } + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + process_data=process_data ) files = self.extract_files(http_executor.server_url, response) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, outputs={ 'status_code': response.status_code, 'body': response.content if not files else '', @@ -57,7 +53,6 @@ class HttpRequestNode(BaseNode): } ) - @classmethod def _extract_variable_selector_to_variable_mapping(cls, node_data: HttpRequestNodeData) -> dict[str, list[str]]: """ @@ -65,9 +60,7 @@ class HttpRequestNode(BaseNode): :param node_data: node data :return: """ - return { - variable_selector.variable: variable_selector.value_selector for variable_selector in node_data.variables - } + return {} def extract_files(self, url: str, response: HttpExecutorResponse) -> list[FileVar]: """ @@ -103,4 +96,4 @@ class HttpRequestNode(BaseNode): mime_type=mimetype, )) - return files \ No newline at end of file + return files diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index d597941ef6..0c5f453208 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -278,21 +278,9 @@ class WorkflowConverter: token=api_based_extension.api_key ) - http_request_variables = [] inputs = {} for v in variables: - http_request_variables.append({ - "variable": v.variable, - "value_selector": ["start", v.variable] - }) - - inputs[v.variable] = '{{' + v.variable + '}}' - - if app_model.mode == AppMode.CHAT.value: - http_request_variables.append({ - "variable": "_query", - "value_selector": ["sys", ".query"] - }) + inputs[v.variable] = '{{#start.' + v.variable + '#}}' request_body = { 'point': APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value, @@ -313,7 +301,6 @@ class WorkflowConverter: "data": { "title": f"HTTP REQUEST {api_based_extension.name}", "type": NodeType.HTTP_REQUEST.value, - "variables": http_request_variables, "method": "post", "url": api_based_extension.api_endpoint, "authorization": { diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 8b94105b44..a6c011944f 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -1,4 +1,3 @@ -from calendar import c import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool @@ -16,8 +15,8 @@ BASIC_NODE_DATA = { # construct variable pool pool = VariablePool(system_variables={}, user_inputs={}) -pool.append_variable(node_id='1', variable_key_list=['123', 'args1'], value=1) -pool.append_variable(node_id='1', variable_key_list=['123', 'args2'], value=2) +pool.append_variable(node_id='a', variable_key_list=['b123', 'args1'], value=1) +pool.append_variable(node_id='a', variable_key_list=['b123', 'args2'], value=2) @pytest.mark.parametrize('setup_http_mock', [['none']], indirect=True) def test_get(setup_http_mock): @@ -26,10 +25,6 @@ def test_get(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }], 'method': 'get', 'url': 'http://example.com', 'authorization': { @@ -61,10 +56,6 @@ def test_no_auth(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }], 'method': 'get', 'url': 'http://example.com', 'authorization': { @@ -91,10 +82,6 @@ def test_custom_authorization_header(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }], 'method': 'get', 'url': 'http://example.com', 'authorization': { @@ -126,12 +113,8 @@ def test_template(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args2'], - }], 'method': 'get', - 'url': 'http://example.com/{{args1}}', + 'url': 'http://example.com/{{#a.b123.args2#}}', 'authorization': { 'type': 'api-key', 'config': { @@ -140,8 +123,8 @@ def test_template(setup_http_mock): 'header': 'api-key', } }, - 'headers': 'X-Header:123\nX-Header2:{{args1}}', - 'params': 'A:b\nTemplate:{{args1}}', + 'headers': 'X-Header:123\nX-Header2:{{#a.b123.args2#}}', + 'params': 'A:b\nTemplate:{{#a.b123.args2#}}', 'body': None, } }, **BASIC_NODE_DATA) @@ -162,10 +145,6 @@ def test_json(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }], 'method': 'post', 'url': 'http://example.com', 'authorization': { @@ -180,7 +159,7 @@ def test_json(setup_http_mock): 'params': 'A:b', 'body': { 'type': 'json', - 'data': '{"a": "{{args1}}"}' + 'data': '{"a": "{{#a.b123.args1#}}"}' }, } }, **BASIC_NODE_DATA) @@ -198,13 +177,6 @@ def test_x_www_form_urlencoded(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'], - }], 'method': 'post', 'url': 'http://example.com', 'authorization': { @@ -219,7 +191,7 @@ def test_x_www_form_urlencoded(setup_http_mock): 'params': 'A:b', 'body': { 'type': 'x-www-form-urlencoded', - 'data': 'a:{{args1}}\nb:{{args2}}' + 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' }, } }, **BASIC_NODE_DATA) @@ -237,13 +209,6 @@ def test_form_data(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'], - }], 'method': 'post', 'url': 'http://example.com', 'authorization': { @@ -258,7 +223,7 @@ def test_form_data(setup_http_mock): 'params': 'A:b', 'body': { 'type': 'form-data', - 'data': 'a:{{args1}}\nb:{{args2}}' + 'data': 'a:{{#a.b123.args1#}}\nb:{{#a.b123.args2#}}' }, } }, **BASIC_NODE_DATA) @@ -279,13 +244,6 @@ def test_none_data(setup_http_mock): 'data': { 'title': 'http', 'desc': '', - 'variables': [{ - 'variable': 'args1', - 'value_selector': ['1', '123', 'args1'], - }, { - 'variable': 'args2', - 'value_selector': ['1', '123', 'args2'], - }], 'method': 'post', 'url': 'http://example.com', 'authorization': { @@ -310,4 +268,4 @@ def test_none_data(setup_http_mock): assert 'api-key: Basic ak-xxx' in data assert 'X-Header: 123' in data - assert '123123123' not in data \ No newline at end of file + assert '123123123' not in data