From 0806b3163ab45f8149acc493bb7b5c33095ebe65 Mon Sep 17 00:00:00 2001 From: takatost Date: Fri, 23 Feb 2024 18:18:49 +0800 Subject: [PATCH] add to http request node convert tests --- api/core/application_manager.py | 8 +- api/core/entities/application_entities.py | 1 + api/services/app_model_config_service.py | 2 +- api/services/workflow/workflow_converter.py | 24 ++- api/tests/unit_tests/services/__init__.py | 0 .../unit_tests/services/workflow/__init__.py | 0 .../workflow/test_workflow_converter.py | 184 ++++++++++++++++++ 7 files changed, 210 insertions(+), 9 deletions(-) create mode 100644 api/tests/unit_tests/services/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/__init__.py create mode 100644 api/tests/unit_tests/services/workflow/test_workflow_converter.py diff --git a/api/core/application_manager.py b/api/core/application_manager.py index cf463be1df..77bb81b0da 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -400,10 +400,14 @@ class ApplicationManager: config=val['config'] ) ) - elif typ in [VariableEntity.Type.TEXT_INPUT.value, VariableEntity.Type.PARAGRAPH.value]: + elif typ in [ + VariableEntity.Type.TEXT_INPUT.value, + VariableEntity.Type.PARAGRAPH.value, + VariableEntity.Type.NUMBER.value, + ]: properties['variables'].append( VariableEntity( - type=VariableEntity.Type.TEXT_INPUT, + type=VariableEntity.Type.value_of(typ), variable=variable[typ].get('variable'), description=variable[typ].get('description'), label=variable[typ].get('label'), diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index f8f293d96a..667940f184 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -94,6 +94,7 @@ class VariableEntity(BaseModel): TEXT_INPUT = 'text-input' SELECT = 'select' PARAGRAPH = 'paragraph' + NUMBER = 'number' @classmethod def value_of(cls, value: str) -> 'VariableEntity.Type': diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 3ac11c645c..aa8cd73ea7 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -205,7 +205,7 @@ class AppModelConfigService: variables = [] for item in config["user_input_form"]: key = list(item.keys())[0] - if key not in ["text-input", "select", "paragraph", "external_data_tool"]: + if key not in ["text-input", "select", "paragraph", "number", "external_data_tool"]: raise ValueError("Keys in user_input_form list can only be 'text-input', 'paragraph' or 'select'") form_item = item[key] diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 1fb37afe01..31df58a583 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -190,10 +190,10 @@ class WorkflowConverter: api_based_extension_id = tool_config.get("api_based_extension_id") # get api_based_extension - api_based_extension = db.session.query(APIBasedExtension).filter( - APIBasedExtension.tenant_id == tenant_id, - APIBasedExtension.id == api_based_extension_id - ).first() + api_based_extension = self._get_api_based_extension( + tenant_id=tenant_id, + api_based_extension_id=api_based_extension_id + ) if not api_based_extension: raise ValueError("[External data tool] API query failed, variable: {}, " @@ -259,7 +259,6 @@ class WorkflowConverter: } } } - index += 1 nodes.append(http_request_node) @@ -268,7 +267,7 @@ class WorkflowConverter: "id": f"code-{index}", "position": None, "data": { - "title": f"Parse {api_based_extension.name} response", + "title": f"Parse {api_based_extension.name} Response", "type": NodeType.CODE.value, "variables": [{ "variable": "response_json", @@ -287,6 +286,7 @@ class WorkflowConverter: } nodes.append(code_node) + index += 1 return nodes @@ -513,3 +513,15 @@ class WorkflowConverter: return AppMode.WORKFLOW else: return AppMode.value_of(app_model.mode) + + def _get_api_based_extension(self, tenant_id: str, api_based_extension_id: str) -> APIBasedExtension: + """ + Get API Based Extension + :param tenant_id: tenant id + :param api_based_extension_id: api based extension id + :return: + """ + return db.session.query(APIBasedExtension).filter( + APIBasedExtension.tenant_id == tenant_id, + APIBasedExtension.id == api_based_extension_id + ).first() diff --git a/api/tests/unit_tests/services/__init__.py b/api/tests/unit_tests/services/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/workflow/__init__.py b/api/tests/unit_tests/services/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py new file mode 100644 index 0000000000..69cf6afe45 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -0,0 +1,184 @@ +# test for api/services/workflow/workflow_converter.py +import json +from unittest.mock import MagicMock + +import pytest + +from core.entities.application_entities import VariableEntity, ExternalDataVariableEntity +from core.helper import encrypter +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint +from models.model import AppMode +from services.workflow.workflow_converter import WorkflowConverter + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity( + variable="text-input", + label="text-input", + type=VariableEntity.Type.TEXT_INPUT + ), + VariableEntity( + variable="paragraph", + label="paragraph", + type=VariableEntity.Type.PARAGRAPH + ), + VariableEntity( + variable="select", + label="select", + type=VariableEntity.Type.SELECT + ) + ] + + +def test__convert_to_start_node(default_variables): + # act + result = WorkflowConverter()._convert_to_start_node(default_variables) + + # assert + assert result["data"]["variables"][0]["variable"] == "text-input" + assert result["data"]["variables"][1]["variable"] == "paragraph" + assert result["data"]["variables"][2]["variable"] == "select" + + +def test__convert_to_http_request_node(default_variables): + """ + Test convert to http request nodes + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.CHAT.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 4 # appended _query variable + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "{{_query}}" # for chatbot + + code_node = nodes[1] + assert code_node["data"]["type"] == "code" + + +def test__convert_to_http_request_node_for_workflow_app(default_variables): + """ + Test convert to http request nodes for workflow app + :return: + """ + app_model = MagicMock() + app_model.id = "app_id" + app_model.tenant_id = "tenant_id" + app_model.mode = AppMode.WORKFLOW.value + + api_based_extension_id = "api_based_extension_id" + mock_api_based_extension = APIBasedExtension( + id=api_based_extension_id, + name="api-1", + api_key="encrypted_api_key", + api_endpoint="https://dify.ai", + ) + + workflow_converter = WorkflowConverter() + workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + external_data_variables = [ + ExternalDataVariableEntity( + variable="external_variable", + type="api", + config={ + "api_based_extension_id": api_based_extension_id + } + ) + ] + + nodes = workflow_converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=external_data_variables + ) + + assert len(nodes) == 2 + assert nodes[0]["data"]["type"] == "http-request" + + http_request_node = nodes[0] + + assert len(http_request_node["data"]["variables"]) == 3 + assert http_request_node["data"]["method"] == "post" + assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint + assert http_request_node["data"]["authorization"]["type"] == "api-key" + assert http_request_node["data"]["authorization"]["config"] == { + "type": "bearer", + "api_key": "api_key" + } + assert http_request_node["data"]["body"]["type"] == "json" + + body_data = http_request_node["data"]["body"]["data"] + + assert body_data + + body_data_json = json.loads(body_data) + assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY.value + + body_params = body_data_json["params"] + assert body_params["app_id"] == app_model.id + assert body_params["tool_variable"] == external_data_variables[0].variable + assert len(body_params["inputs"]) == 3 + assert body_params["query"] == "" + + code_node = nodes[1] + assert code_node["data"]["type"] == "code"