From 1024fc623efd19389742c5c1afa49ebf0a35a342 Mon Sep 17 00:00:00 2001 From: Jyong <76649700+JohnJyong@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:22:07 +0800 Subject: [PATCH 01/29] fix the ssrf of docx file extractor external images (#10237) --- api/core/rag/extractor/word_extractor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index ae3c25125c..d4434ea28f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -14,6 +14,7 @@ import requests from docx import Document as DocxDocument from configs import dify_config +from core.helper import ssrf_proxy from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db @@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor): image_count += 1 if rel.is_external: url = rel.reltype - response = requests.get(url, stream=True) + response = ssrf_proxy.get(url, stream=True) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) From 8b5ea399168e957e737dd62375d592de34dbf3df Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 15:22:31 +0800 Subject: [PATCH 02/29] chore(llm_node): remove unnecessary type ignore for context assignment (#10216) --- api/core/workflow/nodes/llm/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index b4728e6abf..bb9290ddc2 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -103,7 +103,7 @@ class LLMNode(BaseNode[LLMNodeData]): yield event if context: - node_inputs["#context#"] = context # type: ignore + node_inputs["#context#"] = context # fetch model config model_instance, model_config = self._fetch_model_config(self.node_data.model) From 
be96f6e62db1152ca14cd2ab70bd9327907cd9bc Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 15:22:41 +0800 Subject: [PATCH 03/29] refactor(workflow): introduce specific exceptions for code validation (#10218) --- api/core/workflow/nodes/code/code_node.py | 46 +++++++++++++---------- api/core/workflow/nodes/code/exc.py | 16 ++++++++ 2 files changed, 42 insertions(+), 20 deletions(-) create mode 100644 api/core/workflow/nodes/code/exc.py diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 9d7d9027c3..de70af58dd 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.nodes.enums import NodeType from models.workflow import WorkflowNodeExecutionStatus +from .exc import ( + CodeNodeError, + DepthLimitError, + OutputValidationError, +) + class CodeNode(BaseNode[CodeNodeData]): _node_data_cls = CodeNodeData @@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]): # Transform result result = self._transform_result(result, self.node_data.outputs) - except (CodeExecutionError, ValueError) as e: + except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e)) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) @@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: return None else: - raise ValueError(f"Output variable `{variable}` must be a string") + raise OutputValidationError(f"Output variable `{variable}` must be a string") if len(value) > dify_config.CODE_MAX_STRING_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{variable}` must be" f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters" ) @@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: 
return None else: - raise ValueError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: - raise ValueError( + raise OutputValidationError( f"Output variable `{variable}` is out of range," f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}." ) @@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(value, float): # raise error if precision is too high if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: - raise ValueError( + raise OutputValidationError( f"Output variable `{variable}` has too high precision," f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." ) @@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]): :return: """ if depth > dify_config.CODE_MAX_DEPTH: - raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") + raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result = {} if output_schema is None: @@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]): depth=depth + 1, ) else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}.{output_name} is not a valid array." f" make sure all elements are of the same type." ) elif output_value is None: pass else: - raise ValueError(f"Output {prefix}.{output_name} is not a valid type.") + raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") return result @@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]): for output_name, output_config in output_schema.items(): dot = "." 
if prefix else "" if output_name not in result: - raise ValueError(f"Output {prefix}{dot}{output_name} is missing.") + raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") if output_config.type == "object": # check if output is object @@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result.get(output_name), type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an object," f" got {type(result.get(output_name))} instead." ) @@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an array," f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements." ) @@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an array," f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements." 
) @@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]): if isinstance(result[output_name], type(None)): transformed_result[output_name] = None else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name} is not an array," f" got {type(result.get(output_name))} instead." ) else: if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH: - raise ValueError( + raise OutputValidationError( f"The length of output variable `{prefix}{dot}{output_name}` must be" f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements." ) @@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]): if value is None: pass else: - raise ValueError( + raise OutputValidationError( f"Output {prefix}{dot}{output_name}[{i}] is not an object," f" got {type(value)} instead at index {i}." ) @@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]): for i, value in enumerate(result[output_name]) ] else: - raise ValueError(f"Output type {output_config.type} is not supported.") + raise OutputValidationError(f"Output type {output_config.type} is not supported.") parameters_validated[output_name] = True # check if all output parameters are validated if len(parameters_validated) != len(result): - raise ValueError("Not all output parameters are validated.") + raise CodeNodeError("Not all output parameters are validated.") return transformed_result diff --git a/api/core/workflow/nodes/code/exc.py b/api/core/workflow/nodes/code/exc.py new file mode 100644 index 0000000000..d6334fd554 --- /dev/null +++ b/api/core/workflow/nodes/code/exc.py @@ -0,0 +1,16 @@ +class CodeNodeError(ValueError): + """Base class for code node errors.""" + + pass + + +class OutputValidationError(CodeNodeError): + """Raised when there is an output validation error.""" + + pass + + +class DepthLimitError(CodeNodeError): + """Raised when the depth limit is reached.""" + + pass From 2adab7f71a9052b153f338ce7b921ac2ec1aa41e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 
2024 15:22:50 +0800 Subject: [PATCH 04/29] refactor(http_request): add custom exception handling for HTTP request nodes (#10219) --- api/core/workflow/nodes/http_request/exc.py | 18 +++++++++++++++++ .../workflow/nodes/http_request/executor.py | 20 ++++++++++++------- api/core/workflow/nodes/http_request/node.py | 3 ++- 3 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 api/core/workflow/nodes/http_request/exc.py diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/core/workflow/nodes/http_request/exc.py new file mode 100644 index 0000000000..7a5ab7dbc1 --- /dev/null +++ b/api/core/workflow/nodes/http_request/exc.py @@ -0,0 +1,18 @@ +class HttpRequestNodeError(ValueError): + """Custom error for HTTP request node.""" + + +class AuthorizationConfigError(HttpRequestNodeError): + """Raised when authorization config is missing or invalid.""" + + +class FileFetchError(HttpRequestNodeError): + """Raised when a file cannot be fetched.""" + + +class InvalidHttpMethodError(HttpRequestNodeError): + """Raised when an invalid HTTP method is used.""" + + +class ResponseSizeError(HttpRequestNodeError): + """Raised when the response size exceeds the allowed threshold.""" diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 6872478299..6204fc2644 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -18,6 +18,12 @@ from .entities import ( HttpRequestNodeTimeout, Response, ) +from .exc import ( + AuthorizationConfigError, + FileFetchError, + InvalidHttpMethodError, + ResponseSizeError, +) BODY_TYPE_TO_CONTENT_TYPE = { "json": "application/json", @@ -51,7 +57,7 @@ class Executor: # If authorization API key is present, convert the API key using the variable pool if node_data.authorization.type == "api-key": if node_data.authorization.config is None: - raise ValueError("authorization config is required") + raise 
AuthorizationConfigError("authorization config is required") node_data.authorization.config.api_key = variable_pool.convert_template( node_data.authorization.config.api_key ).text @@ -116,7 +122,7 @@ class Executor: file_selector = data[0].file file_variable = self.variable_pool.get_file(file_selector) if file_variable is None: - raise ValueError(f"cannot fetch file with selector {file_selector}") + raise FileFetchError(f"cannot fetch file with selector {file_selector}") file = file_variable.value self.content = file_manager.download(file) case "x-www-form-urlencoded": @@ -155,12 +161,12 @@ class Executor: headers = deepcopy(self.headers) or {} if self.auth.type == "api-key": if self.auth.config is None: - raise ValueError("self.authorization config is required") + raise AuthorizationConfigError("self.authorization config is required") if authorization.config is None: - raise ValueError("authorization config is required") + raise AuthorizationConfigError("authorization config is required") if self.auth.config.api_key is None: - raise ValueError("api_key is required") + raise AuthorizationConfigError("api_key is required") if not authorization.config.header: authorization.config.header = "Authorization" @@ -183,7 +189,7 @@ class Executor: else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE ) if executor_response.size > threshold_size: - raise ValueError( + raise ResponseSizeError( f'{"File" if executor_response.is_file else "Text"} size is too large,' f' max size is {threshold_size / 1024 / 1024:.2f} MB,' f' but current size is {executor_response.readable_size}.' 
@@ -196,7 +202,7 @@ class Executor: do http request depending on api bundle """ if self.method not in {"get", "head", "post", "put", "delete", "patch"}: - raise ValueError(f"Invalid http method {self.method}") + raise InvalidHttpMethodError(f"Invalid http method {self.method}") request_args = { "url": self.url, diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index a037bee665..61c661e587 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -20,6 +20,7 @@ from .entities import ( HttpRequestNodeTimeout, Response, ) +from .exc import HttpRequestNodeError HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): "request": http_executor.to_log(), }, ) - except Exception as e: + except HttpRequestNodeError as e: logger.warning(f"http request node {self.node_id} failed to run: {e}") return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, From 38bca6731c64dfea33fba62a5e71a70504479d8c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 15:22:58 +0800 Subject: [PATCH 05/29] refactor(workflow): introduce specific error handling for LLM nodes (#10221) --- api/core/workflow/nodes/llm/exc.py | 26 +++++++++++++++++++++++ api/core/workflow/nodes/llm/node.py | 33 ++++++++++++++++++----------- 2 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 api/core/workflow/nodes/llm/exc.py diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py new file mode 100644 index 0000000000..f858be2515 --- /dev/null +++ b/api/core/workflow/nodes/llm/exc.py @@ -0,0 +1,26 @@ +class LLMNodeError(ValueError): + """Base class for LLM Node errors.""" + + +class VariableNotFoundError(LLMNodeError): + """Raised when a required variable is not found.""" + + +class InvalidContextStructureError(LLMNodeError): + 
"""Raised when the context structure is invalid.""" + + +class InvalidVariableTypeError(LLMNodeError): + """Raised when the variable type is invalid.""" + + +class ModelNotExistError(LLMNodeError): + """Raised when the specified model does not exist.""" + + +class LLMModeRequiredError(LLMNodeError): + """Raised when LLM mode is required but not provided.""" + + +class NoPromptFoundError(LLMNodeError): + """Raised when no prompt is found in the LLM configuration.""" diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index bb9290ddc2..47b0e25d9c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -56,6 +56,15 @@ from .entities import ( LLMNodeData, ModelConfig, ) +from .exc import ( + InvalidContextStructureError, + InvalidVariableTypeError, + LLMModeRequiredError, + LLMNodeError, + ModelNotExistError, + NoPromptFoundError, + VariableNotFoundError, +) if TYPE_CHECKING: from core.file.models import File @@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]): if self.node_data.memory: query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) if not query: - raise ValueError("Query not found") + raise VariableNotFoundError("Query not found") query = query.text else: query = None @@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]): usage = event.usage finish_reason = event.finish_reason break - except Exception as e: + except LLMNodeError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") def parse_dict(input_dict: Mapping[str, Any]) -> 
str: """ @@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): inputs[variable_selector.variable] = "" inputs[variable_selector.variable] = variable.to_object() @@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if variable is None: - raise ValueError(f"Variable {variable_selector.variable} not found") + raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") if isinstance(variable, NoneSegment): continue inputs[variable_selector.variable] = variable.to_object() @@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]): return variable.value elif isinstance(variable, NoneSegment | ArrayAnySegment): return [] - raise ValueError(f"Invalid variable type: {type(variable)}") + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") def _fetch_context(self, node_data: LLMNodeData): if not node_data.context.enabled: @@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]): context_str += item + "\n" else: if "content" not in item: - raise ValueError(f"Invalid context structure: {item}") + raise InvalidContextStructureError(f"Invalid context structure: {item}") context_str += item["content"] + "\n" @@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) if provider_model is None: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") if provider_model.status == ModelStatus.NO_CONFIGURE: raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") @@ -460,12 
+469,12 @@ class LLMNode(BaseNode[LLMNodeData]): # get model mode model_mode = node_data_model.mode if not model_mode: - raise ValueError("LLM mode is required.") + raise LLMModeRequiredError("LLM mode is required.") model_schema = model_type_instance.get_model_schema(model_name, model_credentials) if not model_schema: - raise ValueError(f"Model {model_name} not exist.") + raise ModelNotExistError(f"Model {model_name} not exist.") return model_instance, ModelConfigWithCredentialsEntity( provider=provider_name, @@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]): filtered_prompt_messages.append(prompt_message) if not filtered_prompt_messages: - raise ValueError( + raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) @@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_template_parser = VariableTemplateParser(template=prompt_template.text) variable_selectors = variable_template_parser.extract_variable_selectors() else: - raise ValueError(f"Invalid prompt template type: {type(prompt_template)}") + raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") variable_mapping = {} for variable_selector in variable_selectors: From 9369cc44e615506f3b8ce6a22a7784671a7d2667 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 15:23:08 +0800 Subject: [PATCH 06/29] refactor(list_operator): replace ValueError with InvalidKeyError (#10222) --- api/core/workflow/nodes/list_operator/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 6053a15d96..0406b97eb8 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -295,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq extract_func = _get_file_extract_number_func(key=order_by) 
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc") else: - raise ValueError(f"Invalid order key: {order_by}") + raise InvalidKeyError(f"Invalid order key: {order_by}") From da204c131d39bfb2c5a8b68be06c1ec84c304268 Mon Sep 17 00:00:00 2001 From: shisaru292 <87224749+shisaru292@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:23:18 +0800 Subject: [PATCH 07/29] fix: missing working directory parameter in script (#10226) --- dev/reformat | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dev/reformat b/dev/reformat index ad83e897d9..94a7f3e6fe 100755 --- a/dev/reformat +++ b/dev/reformat @@ -9,10 +9,10 @@ if ! command -v ruff &> /dev/null || ! command -v dotenv-linter &> /dev/null; th fi # run ruff linter -ruff check --fix ./api +poetry run -C api ruff check --fix ./api # run ruff formatter -ruff format ./api +poetry run -C api ruff format ./api # run dotenv-linter linter -dotenv-linter ./api/.env.example ./web/.env.example +poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example From 64523422228db11b035f207c805e94345c63c288 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 15:55:34 +0800 Subject: [PATCH 08/29] feat(workflow): add configurable workflow file upload limit (#10176) Co-authored-by: JzoNg --- api/.env.example | 3 + api/configs/feature/__init__.py | 5 ++ api/controllers/common/fields.py | 24 ++++++ api/controllers/common/helpers.py | 39 +++++++++ api/controllers/console/explore/parameter.py | 81 +++---------------- api/controllers/console/files/__init__.py | 1 + api/controllers/service_api/app/app.py | 78 +++--------------- api/controllers/web/app.py | 78 +++--------------- .../features/file_upload/manager.py | 5 +- api/fields/file_fields.py | 1 + api/models/__init__.py | 2 - api/models/model.py | 13 +-- docker/.env.example | 1 + docker/docker-compose.yaml | 1 + .../base/file-uploader/constants.ts | 1 + .../components/base/file-uploader/hooks.ts | 3 + 
.../_base/components/file-upload-setting.tsx | 10 ++- web/models/common.ts | 2 +- 18 files changed, 125 insertions(+), 223 deletions(-) create mode 100644 api/controllers/common/fields.py diff --git a/api/.env.example b/api/.env.example index c07c292369..f7bcab6d6d 100644 --- a/api/.env.example +++ b/api/.env.example @@ -327,6 +327,9 @@ SSRF_DEFAULT_MAX_RETRIES=3 BATCH_UPLOAD_LIMIT=10 KEYWORD_DATA_SOURCE_TYPE=database +# Workflow file upload limit +WORKFLOW_FILE_UPLOAD_LIMIT=10 + # CODE EXECUTION CONFIGURATION CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194 CODE_EXECUTION_API_KEY=dify-sandbox diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 0fa926038d..533a24dcbd 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -216,6 +216,11 @@ class FileUploadConfig(BaseSettings): default=20, ) + WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field( + description="Maximum number of files allowed in a workflow upload operation", + default=10, + ) + class HttpConfig(BaseSettings): """ diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py new file mode 100644 index 0000000000..79869916ed --- /dev/null +++ b/api/controllers/common/fields.py @@ -0,0 +1,24 @@ +from flask_restful import fields + +parameters__system_parameters = { + "image_file_size_limit": fields.Integer, + "video_file_size_limit": fields.Integer, + "audio_file_size_limit": fields.Integer, + "file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, +} + +parameters_fields = { + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": 
fields.Nested(parameters__system_parameters), +} diff --git a/api/controllers/common/helpers.py b/api/controllers/common/helpers.py index ed24b265ef..2bae203712 100644 --- a/api/controllers/common/helpers.py +++ b/api/controllers/common/helpers.py @@ -2,11 +2,15 @@ import mimetypes import os import re import urllib.parse +from collections.abc import Mapping +from typing import Any from uuid import uuid4 import httpx from pydantic import BaseModel +from configs import dify_config + class FileInfo(BaseModel): filename: str @@ -56,3 +60,38 @@ def guess_file_info_from_response(response: httpx.Response): mimetype=mimetype, size=int(response.headers.get("Content-Length", -1)), ) + + +def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]): + return { + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": { + "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, + "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, + 
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, + }, + } diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 7c7580e3c6..fee52248a6 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,6 +1,7 @@ -from flask_restful import fields, marshal_with +from flask_restful import marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.console import api from controllers.console.app.error import AppUnavailableError from controllers.console.explore.wraps import InstalledAppResource @@ -11,43 +12,14 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" - variable_fields = { - "key": fields.String, - "name": fields.String, - "description": fields.String, - "type": fields.String, - "default": fields.String, - "max_length": fields.Integer, - "options": fields.List(fields.String), - } - - system_parameters_fields = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - } - - parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(system_parameters_fields), - } - - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) 
def get(self, installed_app: InstalledApp): """Retrieve app parameters.""" app_model = installed_app.app + if app_model is None: + raise AppUnavailableError() + if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: workflow = app_model.workflow if workflow is None: @@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() user_input_form = features_dict.get("user_input_form", []) - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get( - "suggested_questions_after_answer", {"enabled": False} - ), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - }, - } + return 
controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class ExploreAppMetaApi(InstalledAppResource): diff --git a/api/controllers/console/files/__init__.py b/api/controllers/console/files/__init__.py index 69ee7eaabd..6c7bd8acfd 100644 --- a/api/controllers/console/files/__init__.py +++ b/api/controllers/console/files/__init__.py @@ -37,6 +37,7 @@ class FileApi(Resource): "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, + "workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT, }, 200 @setup_required diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 9a4cdc26cd..88b13faa52 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,6 +1,7 @@ -from flask_restful import Resource, fields, marshal_with +from flask_restful import Resource, marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.service_api import api from controllers.service_api.app.error import AppUnavailableError from controllers.service_api.wraps import validate_app_token @@ -11,40 +12,8 @@ from services.app_service import AppService class AppParameterApi(Resource): """Resource for app variables.""" - variable_fields = { - "key": fields.String, - "name": fields.String, - "description": fields.String, - "type": fields.String, - "default": fields.String, - "max_length": fields.Integer, - "options": fields.List(fields.String), - } - - system_parameters_fields = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - } - - parameters_fields = { - "opening_statement": 
fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(system_parameters_fields), - } - @validate_app_token - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, app_model: App): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: @@ -56,43 +25,16 @@ class AppParameterApi(Resource): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() user_input_form = features_dict.get("user_input_form", []) - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get( - "suggested_questions_after_answer", {"enabled": False} - ), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": 
["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - }, - } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class AppMetaApi(Resource): diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 974d2cff94..cc8255ccf4 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,6 +1,7 @@ -from flask_restful import fields, marshal_with +from flask_restful import marshal_with -from configs import dify_config +from controllers.common import fields +from controllers.common import helpers as controller_helpers from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource @@ -11,39 +12,7 @@ from services.app_service import AppService class AppParameterApi(WebApiResource): """Resource for app variables.""" - variable_fields = { - "key": fields.String, - "name": fields.String, - "description": fields.String, - "type": fields.String, - "default": fields.String, - "max_length": fields.Integer, - "options": fields.List(fields.String), - } - - system_parameters_fields = { - "image_file_size_limit": fields.Integer, - "video_file_size_limit": fields.Integer, - "audio_file_size_limit": fields.Integer, - "file_size_limit": fields.Integer, - } - - parameters_fields = { - "opening_statement": fields.String, - "suggested_questions": fields.Raw, - "suggested_questions_after_answer": fields.Raw, - "speech_to_text": fields.Raw, - "text_to_speech": fields.Raw, - "retriever_resource": fields.Raw, - "annotation_reply": fields.Raw, - "more_like_this": fields.Raw, - "user_input_form": fields.Raw, - "sensitive_word_avoidance": 
fields.Raw, - "file_upload": fields.Raw, - "system_parameters": fields.Nested(system_parameters_fields), - } - - @marshal_with(parameters_fields) + @marshal_with(fields.parameters_fields) def get(self, app_model: App, end_user): """Retrieve app parameters.""" if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}: @@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource): user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config + if app_model_config is None: + raise AppUnavailableError() + features_dict = app_model_config.to_dict() user_input_form = features_dict.get("user_input_form", []) - return { - "opening_statement": features_dict.get("opening_statement"), - "suggested_questions": features_dict.get("suggested_questions", []), - "suggested_questions_after_answer": features_dict.get( - "suggested_questions_after_answer", {"enabled": False} - ), - "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), - "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), - "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), - "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), - "more_like_this": features_dict.get("more_like_this", {"enabled": False}), - "user_input_form": user_input_form, - "sensitive_word_avoidance": features_dict.get( - "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} - ), - "file_upload": features_dict.get( - "file_upload", - { - "image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"], - } - }, - ), - "system_parameters": { - "image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT, - "video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT, - "audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT, - "file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT, - }, 
- } + return controller_helpers.get_parameters_from_feature_dict( + features_dict=features_dict, user_input_form=user_input_form + ) class AppMeta(WebApiResource): diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 42beec2535..d0f75d0b75 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,8 +1,7 @@ from collections.abc import Mapping from typing import Any -from core.file.models import FileExtraConfig -from models import FileUploadConfig +from core.file import FileExtraConfig class FileUploadConfigManager: @@ -43,6 +42,6 @@ class FileUploadConfigManager: if not config.get("file_upload"): config["file_upload"] = {} else: - FileUploadConfig.model_validate(config["file_upload"]) + FileExtraConfig.model_validate(config["file_upload"]) return config, ["file_upload"] diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 1cddc24b2c..afaacc0568 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -8,6 +8,7 @@ upload_config_fields = { "image_file_size_limit": fields.Integer, "video_file_size_limit": fields.Integer, "audio_file_size_limit": fields.Integer, + "workflow_file_upload_limit": fields.Integer, } file_fields = { diff --git a/api/models/__init__.py b/api/models/__init__.py index 1d8bae6cfa..cd6c7674da 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -6,7 +6,6 @@ from .model import ( AppMode, Conversation, EndUser, - FileUploadConfig, InstalledApp, Message, MessageAnnotation, @@ -50,6 +49,5 @@ __all__ = [ "Tenant", "Conversation", "MessageAnnotation", - "FileUploadConfig", "ToolFile", ] diff --git a/api/models/model.py b/api/models/model.py index e9c6b6732f..bd124cce8e 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1,7 +1,7 @@ import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc 
import Mapping from datetime import datetime from enum import Enum from typing import Any, Literal, Optional @@ -9,7 +9,6 @@ from typing import Any, Literal, Optional import sqlalchemy as sa from flask import request from flask_login import UserMixin -from pydantic import BaseModel, Field from sqlalchemy import Float, func, text from sqlalchemy.orm import Mapped, mapped_column @@ -25,14 +24,6 @@ from .account import Account, Tenant from .types import StringUUID -class FileUploadConfig(BaseModel): - enabled: bool = Field(default=False) - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_extensions: Sequence[str] = Field(default_factory=list) - allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = Field(default=0, gt=0, le=10) - - class DifySetup(db.Model): __tablename__ = "dify_setups" __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -115,7 +106,7 @@ class App(db.Model): return site @property - def app_model_config(self) -> Optional["AppModelConfig"]: + def app_model_config(self): if self.app_model_config_id: return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() diff --git a/docker/.env.example b/docker/.env.example index 5b82d62d7b..aa5e102bd0 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -690,6 +690,7 @@ WORKFLOW_MAX_EXECUTION_STEPS=500 WORKFLOW_MAX_EXECUTION_TIME=1200 WORKFLOW_CALL_MAX_DEPTH=5 MAX_VARIABLE_SIZE=204800 +WORKFLOW_FILE_UPLOAD_LIMIT=10 # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 12cdf25e70..a26838af10 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -1,4 +1,5 @@ x-shared-env: &shared-api-worker-env + WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} LOG_LEVEL: ${LOG_LEVEL:-INFO} LOG_FILE: ${LOG_FILE:-} 
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20} diff --git a/web/app/components/base/file-uploader/constants.ts b/web/app/components/base/file-uploader/constants.ts index 629fe2566b..a749d73c74 100644 --- a/web/app/components/base/file-uploader/constants.ts +++ b/web/app/components/base/file-uploader/constants.ts @@ -3,5 +3,6 @@ export const IMG_SIZE_LIMIT = 10 * 1024 * 1024 export const FILE_SIZE_LIMIT = 15 * 1024 * 1024 export const AUDIO_SIZE_LIMIT = 50 * 1024 * 1024 export const VIDEO_SIZE_LIMIT = 100 * 1024 * 1024 +export const MAX_FILE_UPLOAD_LIMIT = 10 export const FILE_URL_REGEX = /^(https?|ftp):\/\// diff --git a/web/app/components/base/file-uploader/hooks.ts b/web/app/components/base/file-uploader/hooks.ts index 088160691b..c735754ffe 100644 --- a/web/app/components/base/file-uploader/hooks.ts +++ b/web/app/components/base/file-uploader/hooks.ts @@ -18,6 +18,7 @@ import { AUDIO_SIZE_LIMIT, FILE_SIZE_LIMIT, IMG_SIZE_LIMIT, + MAX_FILE_UPLOAD_LIMIT, VIDEO_SIZE_LIMIT, } from '@/app/components/base/file-uploader/constants' import { useToastContext } from '@/app/components/base/toast' @@ -33,12 +34,14 @@ export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) => const docSizeLimit = Number(fileUploadConfig?.file_size_limit) * 1024 * 1024 || FILE_SIZE_LIMIT const audioSizeLimit = Number(fileUploadConfig?.audio_file_size_limit) * 1024 * 1024 || AUDIO_SIZE_LIMIT const videoSizeLimit = Number(fileUploadConfig?.video_file_size_limit) * 1024 * 1024 || VIDEO_SIZE_LIMIT + const maxFileUploadLimit = Number(fileUploadConfig?.workflow_file_upload_limit) || MAX_FILE_UPLOAD_LIMIT return { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit, + maxFileUploadLimit, } } diff --git a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx index 82a3a906cf..42a7213f80 100644 --- a/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx +++ 
b/web/app/components/workflow/nodes/_base/components/file-upload-setting.tsx @@ -39,7 +39,13 @@ const FileUploadSetting: FC = ({ allowed_file_extensions, } = payload const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig) - const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileUploadConfigResponse) + const { + imgSizeLimit, + docSizeLimit, + audioSizeLimit, + videoSizeLimit, + maxFileUploadLimit, + } = useFileSizeLimit(fileUploadConfigResponse) const handleSupportFileTypeChange = useCallback((type: SupportUploadFileTypes) => { const newPayload = produce(payload, (draft) => { @@ -156,7 +162,7 @@ const FileUploadSetting: FC = ({ diff --git a/web/models/common.ts b/web/models/common.ts index 9ab27a6018..dc2b1120b9 100644 --- a/web/models/common.ts +++ b/web/models/common.ts @@ -216,7 +216,7 @@ export type FileUploadConfigResponse = { file_size_limit: number // default is 15MB audio_file_size_limit?: number // default is 50MB video_file_size_limit?: number // default is 100MB - + workflow_file_upload_limit?: number // default is 10 } export type InvitationResult = { From 2aa171c348dbaeacf6e9604c976fec772bd1df5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=96=B9=E7=A8=8B?= Date: Mon, 4 Nov 2024 17:22:02 +0800 Subject: [PATCH 09/29] Using a dedicated interface to obtain the token credential for the gitee.ai provider (#10243) --- .../model_providers/gitee_ai/gitee_ai.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py index ca67594ce4..14aa811905 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py +++ b/api/core/model_runtime/model_providers/gitee_ai/gitee_ai.py @@ -1,6 +1,7 @@ import logging -from core.model_runtime.entities.model_entities import ModelType +import requests + from core.model_runtime.errors.validate 
import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider @@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider): :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. """ try: - model_instance = self.get_model_instance(ModelType.LLM) - model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials) + api_key = credentials.get("api_key") + if not api_key: + raise CredentialsValidateFailedError("Credentials validation failed: api_key not given") + + # send a get request to validate the credentials + headers = {"Authorization": f"Bearer {api_key}"} + response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300)) + + if response.status_code != 200: + raise CredentialsValidateFailedError( + f"Credentials validation failed with status code {response.status_code}" + ) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: From 87c1de66f21547eb5a0df939dda5210352659f5e Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 17:48:10 +0800 Subject: [PATCH 10/29] chore(Dockerfile): upgrade zlib arm64 (#10244) --- api/Dockerfile | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/api/Dockerfile b/api/Dockerfile index 1f84fab657..eb37303182 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,12 +55,7 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \ - && if [ "$(dpkg --print-architecture)" = "amd64" ]; then \ - apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \ - else \ - apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \ - fi \ + && apt-get install -y --no-install-recommends 
expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ # install a chinese font to support the use of tools like matplotlib && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ From 6b0de08157c5475ef9e1006f9bfe09bba85216c0 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 4 Nov 2024 18:34:55 +0800 Subject: [PATCH 11/29] fix(validation): allow to use 0 in the inputs form (#10255) --- api/core/app/apps/base_app_generator.py | 78 +++++++++++-------- .../core/app/apps/test_base_app_generator.py | 52 +++++++++++++ 2 files changed, 97 insertions(+), 33 deletions(-) create mode 100644 api/tests/unit_tests/core/app/apps/test_base_app_generator.py diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 7daff83533..d8e38476c7 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -22,7 +22,10 @@ class BaseAppGenerator: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values variables = app_config.variables - user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables} + user_inputs = { + var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) + for var in variables + } user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} # Convert files in inputs to File entity_dictionary = {item.variable: item for item in app_config.variables} @@ -74,57 +77,66 @@ class BaseAppGenerator: return user_inputs - def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"): - user_input_value = inputs.get(var.variable) + def _validate_inputs( + self, + *, + variable_entity: "VariableEntity", + value: Any, + ): + if value is None: + if variable_entity.required: + raise ValueError(f"{variable_entity.variable} is required in input form") 
+ return value - if not user_input_value: - if var.required: - raise ValueError(f"{var.variable} is required in input form") - else: - return None - - if var.type in { + if variable_entity.type in { VariableEntityType.TEXT_INPUT, VariableEntityType.SELECT, VariableEntityType.PARAGRAPH, - } and not isinstance(user_input_value, str): - raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string") + } and not isinstance(value, str): + raise ValueError( + f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string" + ) - if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str): + if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str): # may raise ValueError if user_input_value is not a valid number try: - if "." in user_input_value: - return float(user_input_value) + if "." in value: + return float(value) else: - return int(user_input_value) + return int(value) except ValueError: - raise ValueError(f"{var.variable} in input form must be a valid number") + raise ValueError(f"{variable_entity.variable} in input form must be a valid number") - match var.type: + match variable_entity.type: case VariableEntityType.SELECT: - if user_input_value not in var.options: - raise ValueError(f"{var.variable} in input form must be one of the following: {var.options}") + if value not in variable_entity.options: + raise ValueError( + f"{variable_entity.variable} in input form must be one of the following: " + f"{variable_entity.options}" + ) case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH: - if var.max_length and len(user_input_value) > var.max_length: - raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters") + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} " + "characters" + ) case 
VariableEntityType.FILE: - if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File): - raise ValueError(f"{var.variable} in input form must be a file") + if not isinstance(value, dict) and not isinstance(value, File): + raise ValueError(f"{variable_entity.variable} in input form must be a file") case VariableEntityType.FILE_LIST: # if number of files exceeds the limit, raise ValueError if not ( - isinstance(user_input_value, list) - and ( - all(isinstance(item, dict) for item in user_input_value) - or all(isinstance(item, File) for item in user_input_value) - ) + isinstance(value, list) + and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value)) ): - raise ValueError(f"{var.variable} in input form must be a list of files") + raise ValueError(f"{variable_entity.variable} in input form must be a list of files") - if var.max_length and len(user_input_value) > var.max_length: - raise ValueError(f"{var.variable} in input form must be less than {var.max_length} files") + if variable_entity.max_length and len(value) > variable_entity.max_length: + raise ValueError( + f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files" + ) - return user_input_value + return value def _sanitize_value(self, value: Any) -> Any: if isinstance(value, str): diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py new file mode 100644 index 0000000000..a6bf43ab0c --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -0,0 +1,52 @@ +import pytest + +from core.app.app_config.entities import VariableEntity, VariableEntityType +from core.app.apps.base_app_generator import BaseAppGenerator + + +def test_validate_inputs_with_zero(): + base_app_generator = BaseAppGenerator() + + var = VariableEntity( + variable="test_var", + label="test_var", + type=VariableEntityType.NUMBER, + 
required=True, + ) + + # Test with input 0 + result = base_app_generator._validate_inputs( + variable_entity=var, + value=0, + ) + + assert result == 0 + + # Test with input "0" (string) + result = base_app_generator._validate_inputs( + variable_entity=var, + value="0", + ) + + assert result == 0 + + +def test_validate_input_with_none_for_required_variable(): + base_app_generator = BaseAppGenerator() + + for var_type in VariableEntityType: + var = VariableEntity( + variable="test_var", + label="test_var", + type=var_type, + required=True, + ) + + # Test with input None + with pytest.raises(ValueError) as exc_info: + base_app_generator._validate_inputs( + variable_entity=var, + value=None, + ) + + assert str(exc_info.value) == "test_var is required in input form" From 971defbbbd71cf1f63344619044069b37a87ec75 Mon Sep 17 00:00:00 2001 From: guogeer <1500065870@qq.com> Date: Mon, 4 Nov 2024 18:46:39 +0800 Subject: [PATCH 12/29] fix: buitin tool aippt (#10234) Co-authored-by: jinqi.guo --- .../provider/builtin/aippt/tools/aippt.py | 78 ++++++++++++------- api/core/workflow/nodes/tool/tool_node.py | 2 +- 2 files changed, 50 insertions(+), 30 deletions(-) diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index dd9371f70d..38123f125a 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -4,7 +4,7 @@ from hmac import new as hmac_new from json import loads as json_loads from threading import Lock from time import sleep, time -from typing import Any, Optional +from typing import Any from httpx import get, post from requests import get as requests_get @@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, from core.tools.tool.builtin_tool import BuiltinTool -class AIPPTGenerateTool(BuiltinTool): +class AIPPTGenerateToolAdapter: """ A tool for generating a ppt """ _api_base_url = 
URL("https://co.aippt.cn/api") _api_token_cache = {} - _api_token_cache_lock: Optional[Lock] = None _style_cache = {} - _style_cache_lock: Optional[Lock] = None + + _api_token_cache_lock = Lock() + _style_cache_lock = Lock() _task = {} _task_type_map = { "auto": 1, "markdown": 7, } + _tool: BuiltinTool - def __init__(self, **kwargs: Any): - super().__init__(**kwargs) - self._api_token_cache_lock = Lock() - self._style_cache_lock = Lock() + def __init__(self, tool: BuiltinTool = None): + self._tool = tool def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: """ @@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool): """ title = tool_parameters.get("title", "") if not title: - return self.create_text_message("Please provide a title for the ppt") + return self._tool.create_text_message("Please provide a title for the ppt") model = tool_parameters.get("model", "aippt") if not model: - return self.create_text_message("Please provide a model for the ppt") + return self._tool.create_text_message("Please provide a model for the ppt") outline = tool_parameters.get("outline", "") @@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool): ) # get suit - color = tool_parameters.get("color") - style = tool_parameters.get("style") + color: str = tool_parameters.get("color") + style: str = tool_parameters.get("style") if color == "__default__": color_id = "" @@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool): # generate ppt _, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id) - return self.create_text_message( + return self._tool.create_text_message( """the ppt has been created successfully,""" - f"""the ppt url is {ppt_url}""" + f"""the ppt url is {ppt_url} .""" """please give the ppt url to user and direct user to download it.""" ) @@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - 
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( str(self._api_base_url / "ai" / "chat" / "v2" / "task"), @@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool): headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) @@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool): headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60)) @@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id), } response = post( str(self._api_base_url / "design" / "v2" / "save"), headers=headers, data={"task_id": task_id, "template_id": suit_id}, + timeout=(10, 60), ) if response.status_code != 200: @@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool): return token - @classmethod - def 
_calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: + @staticmethod + def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1 + key=secret_key.encode("utf-8"), + msg=f"GET@/api/grant/token/@{timestamp}".encode(), + digestmod=sha1, ).digest() ).decode("utf-8") @@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool): :param credentials: the credentials :return: Tuple[list[dict[id, color]], list[dict[id, style]] """ - if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"): + if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get( + "aippt_secret_key" + ): raise Exception("Please provide aippt credentials") - return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) + return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id) def _get_suit(self, style_id: int, colour_id: int) -> int: """ @@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool): """ headers = { "x-channel": "", - "x-api-key": self.runtime.credentials["aippt_access_key"], - "x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"), + "x-api-key": self._tool.runtime.credentials["aippt_access_key"], + "x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"), } response = get( str(self._api_base_url / "template_component" / "suit" / "search"), @@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool): ], ), ] + + +class AIPPTGenerateTool(BuiltinTool): + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + return AIPPTGenerateToolAdapter(self)._invoke(user_id, 
tool_parameters) + + def get_runtime_parameters(self) -> list[ToolParameter]: + return AIPPTGenerateToolAdapter(self).get_runtime_parameters() + + @classmethod + def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: + return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index df22130d69..0994ccaedb 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -53,7 +53,7 @@ class ToolNode(BaseNode[ToolNodeData]): ) # get parameters - tool_parameters = tool_runtime.get_runtime_parameters() or [] + tool_parameters = tool_runtime.parameters or [] parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, From 7a98dab6a4fd01152c543a656b0771d58b600a20 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 5 Nov 2024 09:27:51 +0800 Subject: [PATCH 13/29] refactor(parameter_extractor): implement custom error classes (#10260) --- .../workflow/nodes/parameter_extractor/exc.py | 50 ++++++++++++++++ .../parameter_extractor_node.py | 57 ++++++++++++------- 2 files changed, 86 insertions(+), 21 deletions(-) create mode 100644 api/core/workflow/nodes/parameter_extractor/exc.py diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/core/workflow/nodes/parameter_extractor/exc.py new file mode 100644 index 0000000000..6511aba185 --- /dev/null +++ b/api/core/workflow/nodes/parameter_extractor/exc.py @@ -0,0 +1,50 @@ +class ParameterExtractorNodeError(ValueError): + """Base error for ParameterExtractorNode.""" + + +class InvalidModelTypeError(ParameterExtractorNodeError): + """Raised when the model is not a Large Language Model.""" + + +class ModelSchemaNotFoundError(ParameterExtractorNodeError): + """Raised when the model schema is not found.""" + + +class InvalidInvokeResultError(ParameterExtractorNodeError): + """Raised when 
the invoke result is invalid.""" + + +class InvalidTextContentTypeError(ParameterExtractorNodeError): + """Raised when the text content type is invalid.""" + + +class InvalidNumberOfParametersError(ParameterExtractorNodeError): + """Raised when the number of parameters is invalid.""" + + +class RequiredParameterMissingError(ParameterExtractorNodeError): + """Raised when a required parameter is missing.""" + + +class InvalidSelectValueError(ParameterExtractorNodeError): + """Raised when a select value is invalid.""" + + +class InvalidNumberValueError(ParameterExtractorNodeError): + """Raised when a number value is invalid.""" + + +class InvalidBoolValueError(ParameterExtractorNodeError): + """Raised when a bool value is invalid.""" + + +class InvalidStringValueError(ParameterExtractorNodeError): + """Raised when a string value is invalid.""" + + +class InvalidArrayValueError(ParameterExtractorNodeError): + """Raised when an array value is invalid.""" + + +class InvalidModelModeError(ParameterExtractorNodeError): + """Raised when the model mode is invalid.""" diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 49546e9356..b64bde8ac5 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -32,6 +32,21 @@ from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionStatus from .entities import ParameterExtractorNodeData +from .exc import ( + InvalidArrayValueError, + InvalidBoolValueError, + InvalidInvokeResultError, + InvalidModelModeError, + InvalidModelTypeError, + InvalidNumberOfParametersError, + InvalidNumberValueError, + InvalidSelectValueError, + InvalidStringValueError, + InvalidTextContentTypeError, + ModelSchemaNotFoundError, + ParameterExtractorNodeError, + RequiredParameterMissingError, +) from .prompts import 
( CHAT_EXAMPLE, CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, @@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema( @@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode): credentials=model_config.credentials, ) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") # fetch memory memory = self._fetch_memory( @@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode): process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) process_data["llm_text"] = text - except Exception as e: + except ParameterExtractorNodeError as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=inputs, @@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode): try: result = self._validate_result(data=node_data, result=result or {}) - except Exception as e: + except ParameterExtractorNodeError as e: error = str(e) # transform result into standard format @@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode): # handle invoke result if not isinstance(invoke_result, LLMResult): - raise ValueError(f"Invalid invoke result: {invoke_result}") + raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content if not isinstance(text, str): - raise ValueError(f"Invalid text content type: {type(text)}. Expected str.") + raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. 
Expected str.") usage = invoke_result.usage tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None @@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode): files=files, ) else: - raise ValueError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {model_mode}") def _generate_prompt_engineering_completion_prompt( self, @@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode): Validate result. """ if len(data.parameters) != len(result): - raise ValueError("Invalid number of parameters") + raise InvalidNumberOfParametersError("Invalid number of parameters") for parameter in data.parameters: if parameter.required and parameter.name not in result: - raise ValueError(f"Parameter {parameter.name} is required") + raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options: - raise ValueError(f"Invalid `select` value for parameter {parameter.name}") + raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float): - raise ValueError(f"Invalid `number` value for parameter {parameter.name}") + raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}") if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool): - raise ValueError(f"Invalid `bool` value for parameter {parameter.name}") + raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}") if parameter.type == "string" and not isinstance(result.get(parameter.name), str): - raise ValueError(f"Invalid `string` value for parameter {parameter.name}") + raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}") if parameter.type.startswith("array"): parameters = result.get(parameter.name) if not 
isinstance(parameters, list): - raise ValueError(f"Invalid `array` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}") nested_type = parameter.type[6:-1] for item in parameters: if nested_type == "number" and not isinstance(item, int | float): - raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}") if nested_type == "string" and not isinstance(item, str): - raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}") if nested_type == "object" and not isinstance(item, dict): - raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}") + raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}") return result def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict: @@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode): user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode): .replace("}γγγ", "") ) else: - raise ValueError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {model_mode} not support.") def _calculate_rest_token( self, @@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode): model_instance, model_config = self._fetch_model_config(node_data.model) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise ValueError("Model is not a Large Language Model") + raise 
InvalidModelTypeError("Model is not a Large Language Model") llm_model = model_instance.model_type_instance model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) if not model_schema: - raise ValueError("Model schema not found") + raise ModelSchemaNotFoundError("Model schema not found") if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) From 9305ad210225dd97813f85cd51cdad3dca379acc Mon Sep 17 00:00:00 2001 From: Matsuda Date: Tue, 5 Nov 2024 10:42:51 +0900 Subject: [PATCH 14/29] feat: support Claude 3.5 Haiku on Amazon Bedrock (#10265) --- .../llm/anthropic.claude-3-5-haiku-v1.yaml | 61 +++++++++++++++++++ .../llm/us.anthropic.claude-3-5-haiku-v1.yaml | 61 +++++++++++++++++++ 2 files changed, 122 insertions(+) create mode 100644 api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml create mode 100644 api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 0000000000..7c676136db --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,61 @@ +model: anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 Haiku +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic 
Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. 
+ - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml new file mode 100644 index 0000000000..a9b66b1925 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml @@ -0,0 +1,61 @@ +model: us.anthropic.claude-3-5-haiku-20241022-v1:0 +label: + en_US: Claude 3.5 Haiku(US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html +parameter_rules: + - name: max_tokens + use_template: max_tokens + required: true + type: int + default: 4096 + min: 1 + max: 4096 + help: + zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 + en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. + # docs: https://docs.anthropic.com/claude/docs/system-prompts + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. 
+ - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. + - name: response_format + use_template: response_format +pricing: + input: '0.001' + output: '0.005' + unit: '0.001' + currency: USD From 2c4d8dbe9b249d200def5b8c512ef8b8dee56624 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 5 Nov 2024 09:49:43 +0800 Subject: [PATCH 15/29] feat(document_extractor): support tool file in document extractor (#10217) --- api/core/workflow/nodes/document_extractor/node.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index aacee94095..c90017d5e1 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -198,10 +198,8 @@ def _download_file_content(file: File) -> bytes: response = ssrf_proxy.get(file.remote_url) response.raise_for_status() return response.content - elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - return file_manager.download(file) else: - raise ValueError(f"Unsupported transfer method: {file.transfer_method}") + return file_manager.download(file) except Exception as e: raise FileDownloadError(f"Error 
downloading file: {str(e)}") from e From cca2e7876d67596338e4732e3a5f9424f98705de Mon Sep 17 00:00:00 2001 From: GeorgeCaoJ Date: Tue, 5 Nov 2024 09:56:41 +0800 Subject: [PATCH 16/29] fix(workflow): handle else condition branch addition error in if-else node (#10257) --- .../workflow/nodes/if-else/use-config.ts | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/web/app/components/workflow/nodes/if-else/use-config.ts b/web/app/components/workflow/nodes/if-else/use-config.ts index d1210431a0..41e41f6b8b 100644 --- a/web/app/components/workflow/nodes/if-else/use-config.ts +++ b/web/app/components/workflow/nodes/if-else/use-config.ts @@ -78,24 +78,24 @@ const useConfig = (id: string, payload: IfElseNodeType) => { }) const handleAddCase = useCallback(() => { - const newInputs = produce(inputs, () => { - if (inputs.cases) { + const newInputs = produce(inputs, (draft) => { + if (draft.cases) { const case_id = uuid4() - inputs.cases.push({ + draft.cases.push({ case_id, logical_operator: LogicalOperator.and, conditions: [], }) - if (inputs._targetBranches) { - const elseCaseIndex = inputs._targetBranches.findIndex(branch => branch.id === 'false') + if (draft._targetBranches) { + const elseCaseIndex = draft._targetBranches.findIndex(branch => branch.id === 'false') if (elseCaseIndex > -1) { - inputs._targetBranches = branchNameCorrect([ - ...inputs._targetBranches.slice(0, elseCaseIndex), + draft._targetBranches = branchNameCorrect([ + ...draft._targetBranches.slice(0, elseCaseIndex), { id: case_id, name: '', }, - ...inputs._targetBranches.slice(elseCaseIndex), + ...draft._targetBranches.slice(elseCaseIndex), ]) } } From d1505b15c40aa7bfe71bdc6f07efacc885985b34 Mon Sep 17 00:00:00 2001 From: Novice Date: Tue, 5 Nov 2024 10:32:49 +0800 Subject: [PATCH 17/29] feat: Iteration node support parallel mode (#9493) --- .../advanced_chat/generate_task_pipeline.py | 3 +- .../apps/workflow/generate_task_pipeline.py | 3 +- 
api/core/app/apps/workflow_app_runner.py | 35 ++ api/core/app/entities/queue_entities.py | 37 +- api/core/app/entities/task_entities.py | 2 + .../task_pipeline/workflow_cycle_manage.py | 28 +- api/core/workflow/entities/node_entities.py | 1 + .../workflow/graph_engine/entities/event.py | 7 + .../workflow/graph_engine/graph_engine.py | 11 + api/core/workflow/nodes/iteration/entities.py | 10 + .../nodes/iteration/iteration_node.py | 412 ++++++++++++---- .../nodes/iteration/test_iteration.py | 449 +++++++++++++++++- web/app/components/base/select/index.tsx | 2 +- web/app/components/workflow/constants.ts | 4 +- .../workflow/hooks/use-nodes-interactions.ts | 5 + .../workflow/hooks/use-workflow-run.ts | 102 +++- .../workflow/nodes/_base/components/field.tsx | 6 +- .../components/workflow/nodes/_base/node.tsx | 24 +- .../workflow/nodes/iteration/default.ts | 39 +- .../workflow/nodes/iteration/node.tsx | 15 +- .../workflow/nodes/iteration/panel.tsx | 59 ++- .../workflow/nodes/iteration/types.ts | 5 + .../workflow/nodes/iteration/use-config.ts | 25 +- .../workflow/panel/debug-and-preview/hooks.ts | 12 +- web/app/components/workflow/run/index.tsx | 77 ++- .../workflow/run/iteration-result-panel.tsx | 20 +- web/app/components/workflow/run/node.tsx | 16 +- web/app/components/workflow/store.ts | 10 + web/app/components/workflow/types.ts | 6 +- web/app/components/workflow/utils.ts | 11 +- web/i18n/en-US/workflow.ts | 17 + web/i18n/zh-Hans/workflow.ts | 17 + web/types/workflow.ts | 5 + 33 files changed, 1283 insertions(+), 192 deletions(-) diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e4cb3f8527..1fc7ffe2c7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -20,6 +20,7 @@ from core.app.entities.queue_entities import ( QueueIterationStartEvent, QueueMessageReplaceEvent, QueueNodeFailedEvent, + 
QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if response: yield response - elif isinstance(event, QueueNodeFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 419a5da806..d119d94a61 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -16,6 +16,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if response: yield response - elif isinstance(event, QueueNodeFailedEvent): + elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent): workflow_node_execution = self._handle_workflow_node_execution_failed(event) response = self._workflow_node_finish_to_stream_response( diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index ca23bbdd47..9a01e8a253 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -9,6 +9,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import ( 
IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeInIterationFailedEvent, NodeRunFailedEvent, NodeRunRetrieverResourceEvent, NodeRunStartedEvent, @@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner): node_run_index=event.route_node_state.index, predecessor_node_id=event.predecessor_node_id, in_iteration_id=event.in_iteration_id, + parallel_mode_run_id=event.parallel_mode_run_id, ) ) elif isinstance(event, NodeRunSucceededEvent): @@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner): error=event.route_node_state.node_run_result.error if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error else "Unknown error", + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, in_iteration_id=event.in_iteration_id, ) ) + elif isinstance(event, NodeInIterationFailedEvent): + self._publish_event( + QueueNodeInIterationFailedEvent( + node_execution_id=event.id, + node_id=event.node_id, + node_type=event.node_type, + node_data=event.node_data, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + start_at=event.route_node_state.start_at, + inputs=event.route_node_state.node_run_result.inputs + if event.route_node_state.node_run_result + else {}, + process_data=event.route_node_state.node_run_result.process_data + if event.route_node_state.node_run_result + else {}, + outputs=event.route_node_state.node_run_result.outputs + if event.route_node_state.node_run_result + else {}, + execution_metadata=event.route_node_state.node_run_result.metadata + if event.route_node_state.node_run_result + else {}, + in_iteration_id=event.in_iteration_id, + error=event.error, + ) + ) elif isinstance(event, NodeRunStreamChunkEvent): self._publish_event( QueueTextChunkEvent( @@ -326,6 +360,7 @@ class 
WorkflowBasedAppRunner(AppRunner): index=event.index, node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps, output=event.pre_iteration_output, + parallel_mode_run_id=event.parallel_mode_run_id, ) ) elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index bc43baf8a5..f1542ec5d8 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent): """parent parallel id if node is in parallel""" parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" - + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" node_run_index: int output: Optional[Any] = None # output for the current iteration @@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent): in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" start_at: datetime + parallel_mode_run_id: Optional[str] = None + """iteratoin run in parallel mode run id""" class QueueNodeSucceededEvent(AppQueueEvent): @@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent): error: Optional[str] = None +class QueueNodeInIterationFailedEvent(AppQueueEvent): + """ + QueueNodeInIterationFailedEvent entity + """ + + event: QueueEvent = QueueEvent.NODE_FAILED + + node_execution_id: str + node_id: str + node_type: NodeType + node_data: BaseNodeData + parallel_id: Optional[str] = None + """parallel id if node is in parallel""" + parallel_start_node_id: Optional[str] = None + """parallel start node id if node is in parallel""" + parent_parallel_id: Optional[str] = None + """parent parallel id if node is in parallel""" + parent_parallel_start_node_id: Optional[str] = None + """parent parallel start node id if node is in parallel""" + in_iteration_id: Optional[str] = None + 
"""iteration id if node is in iteration""" + start_at: datetime + + inputs: Optional[dict[str, Any]] = None + process_data: Optional[dict[str, Any]] = None + outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None + + error: str + + class QueueNodeFailedEvent(AppQueueEvent): """ QueueNodeFailedEvent entity @@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent): inputs: Optional[dict[str, Any]] = None process_data: Optional[dict[str, Any]] = None outputs: Optional[dict[str, Any]] = None + execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 4b5f4716ed..7e9aad54be 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse): parent_parallel_id: Optional[str] = None parent_parallel_start_node_id: Optional[str] = None iteration_id: Optional[str] = None + parallel_run_id: Optional[str] = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str @@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse): extras: dict = {} parallel_id: Optional[str] = None parallel_start_node_id: Optional[str] = None + parallel_mode_run_id: Optional[str] = None event: StreamEvent = StreamEvent.ITERATION_NEXT workflow_run_id: str diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 2abee5bef5..b89edf9079 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -12,6 +12,7 @@ from core.app.entities.queue_entities import ( QueueIterationNextEvent, QueueIterationStartEvent, QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, QueueParallelBranchRunFailedEvent, @@ -35,6 +36,7 @@ from 
core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager +from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData @@ -251,6 +253,12 @@ class WorkflowCycleManage: workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value workflow_node_execution.created_by_role = workflow_run.created_by_role workflow_node_execution.created_by = workflow_run.created_by + workflow_node_execution.execution_metadata = json.dumps( + { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + } + ) workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) session.add(workflow_node_execution) @@ -305,7 +313,9 @@ class WorkflowCycleManage: return workflow_node_execution - def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution: + def _handle_workflow_node_execution_failed( + self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent + ) -> WorkflowNodeExecution: """ Workflow node execution failed :param event: queue node failed event @@ -318,16 +328,19 @@ class WorkflowCycleManage: outputs = WorkflowEntry.handle_special_values(event.outputs) finished_at = datetime.now(timezone.utc).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - + execution_metadata = ( + json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None + ) db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( { WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, 
WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.elapsed_time: elapsed_time, + WorkflowNodeExecution.execution_metadata: execution_metadata, } ) @@ -342,6 +355,7 @@ class WorkflowCycleManage: workflow_node_execution.outputs = json.dumps(outputs) if outputs else None workflow_node_execution.finished_at = finished_at workflow_node_execution.elapsed_time = elapsed_time + workflow_node_execution.execution_metadata = execution_metadata self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id) @@ -448,6 +462,7 @@ class WorkflowCycleManage: parent_parallel_id=event.parent_parallel_id, parent_parallel_start_node_id=event.parent_parallel_start_node_id, iteration_id=event.in_iteration_id, + parallel_run_id=event.parallel_mode_run_id, ), ) @@ -464,7 +479,7 @@ class WorkflowCycleManage: def _workflow_node_finish_to_stream_response( self, - event: QueueNodeSucceededEvent | QueueNodeFailedEvent, + event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution, ) -> Optional[NodeFinishStreamResponse]: @@ -608,6 +623,7 @@ class WorkflowCycleManage: extras={}, parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, + parallel_mode_run_id=event.parallel_mode_run_id, ), ) @@ -633,7 +649,9 @@ class WorkflowCycleManage: created_at=int(time.time()), extras={}, inputs=event.inputs or {}, - status=WorkflowNodeExecutionStatus.SUCCEEDED, + status=WorkflowNodeExecutionStatus.SUCCEEDED + if event.error is None + else WorkflowNodeExecutionStatus.FAILED, 
error=None, elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 0131bb342b..7e10cddc71 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum): PARALLEL_START_NODE_ID = "parallel_start_node_id" PARENT_PARALLEL_ID = "parent_parallel_id" PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" + PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" class NodeRunResult(BaseModel): diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 86d89e0a32..bacea191dd 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent): class NodeRunStartedEvent(BaseNodeEvent): predecessor_node_id: Optional[str] = None + parallel_mode_run_id: Optional[str] = None """predecessor node id""" @@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent): error: str = Field(..., description="error") +class NodeInIterationFailedEvent(BaseNodeEvent): + error: str = Field(..., description="error") + + ########################################### # Parallel Branch Events ########################################### @@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent): """parent parallel id if node is in parallel""" parent_parallel_start_node_id: Optional[str] = None """parent parallel start node id if node is in parallel""" + parallel_mode_run_id: Optional[str] = None + """iteration run in parallel mode run id""" class IterationRunStartedEvent(BaseIterationEvent): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 
8f58af00ef..f07ad4de11 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -4,6 +4,7 @@ import time import uuid from collections.abc import Generator, Mapping from concurrent.futures import ThreadPoolExecutor, wait +from copy import copy, deepcopy from typing import Any, Optional from flask import Flask, current_app @@ -724,6 +725,16 @@ class GraphEngine: """ return time.perf_counter() - start_at > max_execution_time + def create_copy(self): + """ + create a graph engine copy + :return: with a new variable pool instance of graph engine + """ + new_instance = copy(self) + new_instance.graph_runtime_state = copy(self.graph_runtime_state) + new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool) + return new_instance + class GraphRunFailedError(Exception): def __init__(self, error: str): diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index 4afc870e50..ebcb6f82fb 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Optional from pydantic import Field @@ -5,6 +6,12 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +class ErrorHandleMode(str, Enum): + TERMINATED = "terminated" + CONTINUE_ON_ERROR = "continue-on-error" + REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" + + class IterationNodeData(BaseIterationNodeData): """ Iteration Node Data. 
@@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData): parent_loop_id: Optional[str] = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector + is_parallel: bool = False # open the parallel mode or not + parallel_nums: int = 10 # the numbers of parallel + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error class IterationStartNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index af79da9215..d121b0530a 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -1,12 +1,20 @@ import logging +import uuid from collections.abc import Generator, Mapping, Sequence +from concurrent.futures import Future, wait from datetime import datetime, timezone -from typing import Any, cast +from queue import Empty, Queue +from typing import TYPE_CHECKING, Any, Optional, cast + +from flask import Flask, current_app from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder -from core.variables import IntegerSegment -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import ( + NodeRunMetadataKey, + NodeRunResult, +) +from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import ( IterationRunNextEvent, IterationRunStartedEvent, IterationRunSucceededEvent, + NodeInIterationFailedEvent, + NodeRunFailedEvent, + NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) @@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums 
import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from core.workflow.nodes.iteration.entities import IterationNodeData +from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from models.workflow import WorkflowNodeExecutionStatus +if TYPE_CHECKING: + from core.workflow.graph_engine.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]): _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + @classmethod + def get_default_config(cls, filters: Optional[dict] = None) -> dict: + return { + "type": "iteration", + "config": { + "is_parallel": False, + "parallel_nums": 10, + "error_handle_mode": ErrorHandleMode.TERMINATED.value, + }, + } + def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """ Run the node. @@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "item"], iterator_list_value[0]) # init graph engine - from core.workflow.graph_engine.graph_engine import GraphEngine + from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool graph_engine = GraphEngine( tenant_id=self.tenant_id, @@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]): index=0, pre_iteration_output=None, ) - outputs: list[Any] = [] try: - for _ in range(len(iterator_list_value)): - # run workflow - rst = graph_engine.run() - for event in rst: - if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: - event.in_iteration_id = self.node_id - - if ( - isinstance(event, BaseNodeEvent) - and event.node_type == NodeType.ITERATION_START - and not isinstance(event, NodeRunStreamChunkEvent) - ): + if self.node_data.is_parallel: + futures: list[Future] = [] + q = Queue() + thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100) + for index, item in 
enumerate(iterator_list_value): + future: Future = thread_pool.submit( + self._run_single_iter_parallel, + current_app._get_current_object(), + q, + iterator_list_value, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, + index, + item, + ) + future.add_done_callback(thread_pool.task_done_callback) + futures.append(future) + succeeded_count = 0 + while True: + try: + event = q.get(timeout=1) + if event is None: + break + if isinstance(event, IterationRunNextEvent): + succeeded_count += 1 + if succeeded_count == len(futures): + q.put(None) + yield event + if isinstance(event, RunCompletedEvent): + q.put(None) + for f in futures: + if not f.done(): + f.cancel() + yield event + if isinstance(event, IterationRunFailedEvent): + q.put(None) + yield event + except Empty: continue - if isinstance(event, NodeRunSucceededEvent): - if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} - - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id - index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(index_variable, IntegerSegment): - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Invalid index variable type: {type(index_variable)}", - ) - ) - return - metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value - event.route_node_state.node_run_result.metadata = metadata - - yield event - elif isinstance(event, BaseGraphEvent): - if isinstance(event, GraphRunFailedEvent): - # iteration run failed - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": jsonable_encoder(outputs)}, - steps=len(iterator_list_value), - metadata={"total_tokens": 
graph_engine.graph_runtime_state.total_tokens}, - error=event.error, - ) - - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=event.error, - ) - ) - return - else: - event = cast(InNodeEvent, event) - yield event - - # append to iteration output variable list - current_iteration_output_variable = variable_pool.get(self.node_data.output_selector) - if current_iteration_output_variable is None: - yield RunCompletedEvent( - run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Iteration output variable {self.node_data.output_selector} not found", - ) + # wait all threads + wait(futures) + else: + for _ in range(len(iterator_list_value)): + yield from self._run_single_iter( + iterator_list_value, + variable_pool, + inputs, + outputs, + start_at, + graph_engine, + iteration_graph, ) - return - current_iteration_output = current_iteration_output_variable.to_object() - outputs.append(current_iteration_output) - - # remove all nodes outputs from variable pool - for node_id in iteration_graph.node_ids: - variable_pool.remove([node_id]) - - # move to next iteration - current_index_variable = variable_pool.get([self.node_id, "index"]) - if not isinstance(current_index_variable, IntegerSegment): - raise ValueError(f"iteration {self.node_id} current index not found") - - next_index = current_index_variable.value + 1 - variable_pool.add([self.node_id, "index"], next_index) - - if next_index < len(iterator_list_value): - variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - - yield IterationRunNextEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - index=next_index, - pre_iteration_output=jsonable_encoder(current_iteration_output), - ) - yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, @@ -330,3 +310,231 @@ class 
IterationNode(BaseNode[IterationNodeData]): } return variable_mapping + + def _handle_event_metadata( + self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str + ) -> NodeRunStartedEvent | BaseNodeEvent: + """ + add iteration metadata to event. + """ + if not isinstance(event, BaseNodeEvent): + return event + if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + event.parallel_mode_run_id = parallel_mode_run_id + return event + if event.route_node_state.node_run_result: + metadata = event.route_node_state.node_run_result.metadata + if not metadata: + metadata = {} + + if NodeRunMetadataKey.ITERATION_ID not in metadata: + metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id + if self.node_data.is_parallel: + metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id + else: + metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index + event.route_node_state.node_run_result.metadata = metadata + return event + + def _run_single_iter( + self, + iterator_list_value: list[str], + variable_pool: VariablePool, + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + parallel_mode_run_id: Optional[str] = None, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration + """ + try: + rst = graph_engine.run() + # get current iteration index + current_index = variable_pool.get([self.node_id, "index"]).value + next_index = int(current_index) + 1 + + if current_index is None: + raise ValueError(f"iteration {self.node_id} current index not found") + for event in rst: + if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id: + event.in_iteration_id = self.node_id + + if ( + isinstance(event, BaseNodeEvent) + and event.node_type == NodeType.ITERATION_START + and not isinstance(event, NodeRunStreamChunkEvent) + ): + continue + + if isinstance(event, NodeRunSucceededEvent): + yield 
self._handle_event_metadata(event, current_index, parallel_mode_run_id) + elif isinstance(event, BaseGraphEvent): + if isinstance(event, GraphRunFailedEvent): + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": jsonable_encoder(outputs)}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return + else: + event = cast(InNodeEvent, event) + metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id) + if isinstance(event, NodeRunFailedEvent): + if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + outputs.insert(current_index, None) + variable_pool.add([self.node_id, "index"], next_index) + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + 
return + elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), + ) + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=None, + ) + return + elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + yield metadata_event + + current_iteration_output = variable_pool.get(self.node_data.output_selector).value + outputs.insert(current_index, current_iteration_output) + # remove all nodes outputs from variable pool + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # move to next iteration + variable_pool.add([self.node_id, "index"], next_index) + + if next_index < len(iterator_list_value): + variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) + yield IterationRunNextEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + index=next_index, + parallel_mode_run_id=parallel_mode_run_id, + pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None, + ) + + except Exception as e: + logger.exception(f"Iteration run failed:{str(e)}") + yield 
IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": None}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=str(e), + ) + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + ) + ) + + def _run_single_iter_parallel( + self, + flask_app: Flask, + q: Queue, + iterator_list_value: list[str], + inputs: dict[str, list], + outputs: list, + start_at: datetime, + graph_engine: "GraphEngine", + iteration_graph: Graph, + index: int, + item: Any, + ) -> Generator[NodeEvent | InNodeEvent, None, None]: + """ + run single iteration in parallel mode + """ + with flask_app.app_context(): + parallel_mode_run_id = uuid.uuid4().hex + graph_engine_copy = graph_engine.create_copy() + variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool + variable_pool_copy.add([self.node_id, "index"], index) + variable_pool_copy.add([self.node_id, "item"], item) + for event in self._run_single_iter( + iterator_list_value=iterator_list_value, + variable_pool=variable_pool_copy, + inputs=inputs, + outputs=outputs, + start_at=start_at, + graph_engine=graph_engine_copy, + iteration_graph=iteration_graph, + parallel_mode_run_id=parallel_mode_run_id, + ): + q.put(event) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index d755faee8a..29bd4d6c6c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from 
core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent +from core.workflow.nodes.iteration.entities import ErrorHandleMode from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode from models.enums import UserFrom @@ -185,8 +186,6 @@ def test_run(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() @@ -404,18 +403,458 @@ def test_run_parallel(): outputs={"output": "dify 123"}, ) - # print("") - with patch.object(TemplateTransformNode, "_run", new=tt_generator): # execute node result = iteration_node._run() count = 0 for item in result: - # print(type(item), item) count += 1 if isinstance(item, RunCompletedEvent): assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} assert count == 32 + + +def test_iteration_run_in_parallel_mode(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "iteration-start-source-tt-target", + "source": "iteration-start", + "target": "tt", + }, + { + "id": "iteration-start-source-tt-2-target", + "source": "iteration-start", + "target": "tt-2", + }, + { + "id": "tt-source-if-else-target", + "source": "tt", + "target": "if-else", + }, + { + "id": "tt-2-source-if-else-target", + "source": "tt-2", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "answer-2", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "answer-4", + }, + { + "id": 
"pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "answer": "{{#tt.output#}}", + "iteration_id": "iteration-1", + "title": "answer 2", + "type": "answer", + }, + "id": "answer-2", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 123", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }} 321", + "title": "template transform", + "type": "template-transform", + "variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}], + }, + "id": "tt-2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "hi", + "variable_selector": ["sys", "query"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"}, + "id": "answer-4", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + 
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) + + parallel_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + sequential_iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "迭代", + "type": "iteration", + "is_parallel": True, + }, + "id": "iteration-1", + }, + ) + + def tt_generator(self): + return NodeRunResult( + 
status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"iterator_selector": "dify"}, + outputs={"output": "dify 123"}, + ) + + with patch.object(TemplateTransformNode, "_run", new=tt_generator): + # execute node + parallel_result = parallel_iteration_node._run() + sequential_result = sequential_iteration_node._run() + assert parallel_iteration_node.node_data.parallel_nums == 10 + assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED + count = 0 + parallel_arr = [] + sequential_arr = [] + for item in parallel_result: + count += 1 + parallel_arr.append(item) + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 32 + + for item in sequential_result: + sequential_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]} + assert count == 64 + + +def test_iteration_run_error_handle(): + graph_config = { + "edges": [ + { + "id": "start-source-pe-target", + "source": "start", + "target": "pe", + }, + { + "id": "iteration-1-source-answer-3-target", + "source": "iteration-1", + "target": "answer-3", + }, + { + "id": "tt-source-if-else-target", + "source": "iteration-start", + "target": "if-else", + }, + { + "id": "if-else-true-answer-2-target", + "source": "if-else", + "sourceHandle": "true", + "target": "tt", + }, + { + "id": "if-else-false-answer-4-target", + "source": "if-else", + "sourceHandle": "false", + "target": "tt2", + }, + { + "id": "pe-source-iteration-1-target", + "source": "pe", + "target": "iteration-1", + }, + ], + "nodes": [ + {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"}, + { + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt2", "output"], + 
"output_type": "array[string]", + "start_node_id": "if-else", + "title": "iteration", + "type": "iteration", + }, + "id": "iteration-1", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1.split(arg2) }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + {"value_selector": ["iteration-1", "index"], "variable": "arg2"}, + ], + }, + "id": "tt", + }, + { + "data": { + "iteration_id": "iteration-1", + "template": "{{ arg1 }}", + "title": "template transform", + "type": "template-transform", + "variables": [ + {"value_selector": ["iteration-1", "item"], "variable": "arg1"}, + ], + }, + "id": "tt2", + }, + { + "data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"}, + "id": "answer-3", + }, + { + "data": { + "iteration_id": "iteration-1", + "title": "iteration-start", + "type": "iteration-start", + }, + "id": "iteration-start", + }, + { + "data": { + "conditions": [ + { + "comparison_operator": "is", + "id": "1721916275284", + "value": "1", + "variable_selector": ["iteration-1", "item"], + } + ], + "iteration_id": "iteration-1", + "logical_operator": "and", + "title": "if", + "type": "if-else", + }, + "id": "if-else", + }, + { + "data": { + "instruction": "test1", + "model": { + "completion_params": {"temperature": 0.7}, + "mode": "chat", + "name": "gpt-4o", + "provider": "openai", + }, + "parameters": [ + {"description": "test", "name": "list_output", "required": False, "type": "array[string]"} + ], + "query": ["sys", "query"], + "reasoning_mode": "prompt", + "title": "pe", + "type": "parameter-extractor", + }, + "id": "pe", + }, + ], + } + + graph = Graph.init(graph_config=graph_config) + + init_params = GraphInitParams( + tenant_id="1", + app_id="1", + workflow_type=WorkflowType.CHAT, + workflow_id="1", + graph_config=graph_config, + user_id="1", + user_from=UserFrom.ACCOUNT, + 
invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + # construct variable pool + pool = VariablePool( + system_variables={ + SystemVariableKey.QUERY: "dify", + SystemVariableKey.FILES: [], + SystemVariableKey.CONVERSATION_ID: "abababa", + SystemVariableKey.USER_ID: "1", + }, + user_inputs={}, + environment_variables=[], + ) + pool.add(["pe", "list_output"], ["1", "1"]) + iteration_node = IterationNode( + id=str(uuid.uuid4()), + graph_init_params=init_params, + graph=graph, + graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()), + config={ + "data": { + "iterator_selector": ["pe", "list_output"], + "output_selector": ["tt", "output"], + "output_type": "array[string]", + "startNodeType": "template-transform", + "start_node_id": "iteration-start", + "title": "iteration", + "type": "iteration", + "is_parallel": True, + "error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR, + }, + "id": "iteration-1", + }, + ) + # execute continue on error node + result = iteration_node._run() + result_arr = [] + count = 0 + for item in result: + result_arr.append(item) + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": [None, None]} + + assert count == 14 + # execute remove abnormal output + iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT + result = iteration_node._run() + count = 0 + for item in result: + count += 1 + if isinstance(item, RunCompletedEvent): + assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.run_result.outputs == {"output": []} + assert count == 14 diff --git a/web/app/components/base/select/index.tsx b/web/app/components/base/select/index.tsx index ee5cee977b..c70cf24661 100644 --- a/web/app/components/base/select/index.tsx +++ b/web/app/components/base/select/index.tsx @@ -125,7 +125,7 @@ const Select: FC = ({ - {filteredItems.length > 
0 && ( + {(filteredItems.length > 0 && open) && ( {filteredItems.map((item: Item) => ( { newNode.data.isInIteration = true newNode.data.iteration_id = prevNode.parentId newNode.zIndex = ITERATION_CHILDREN_Z_INDEX + if (newNode.data.type === BlockEnum.Answer || newNode.data.type === BlockEnum.Tool || newNode.data.type === BlockEnum.Assigner) { + const parentIterNodeIndex = nodes.findIndex(node => node.id === prevNode.parentId) + const iterNodeData: IterationNodeType = nodes[parentIterNodeIndex].data + iterNodeData._isShowTips = true + } } const newEdge: Edge = { diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index 0bbb1adab8..26654ef71e 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -14,6 +14,7 @@ import { NodeRunningStatus, WorkflowRunningStatus, } from '../types' +import { DEFAULT_ITER_TIMES } from '../constants' import { useWorkflowUpdate } from './use-workflow-interactions' import { useStore as useAppStore } from '@/app/components/app/store' import type { IOtherOptions } from '@/service/base' @@ -170,11 +171,13 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterParallelLogMap, } = workflowStore.getState() const { edges, setEdges, } = store.getState() + setIterParallelLogMap(new Map()) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.task_id = task_id draft.result = { @@ -244,6 +247,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -259,10 +264,21 @@ export const useWorkflowRun = () => { const tracing = draft.tracing! 
const iterations = tracing.find(trace => trace.node_id === node?.parentId) const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1] - currIteration?.push({ - ...data, - status: NodeRunningStatus.Running, - } as any) + if (!data.parallel_run_id) { + currIteration?.push({ + ...data, + status: NodeRunningStatus.Running, + } as any) + } + else { + if (!iterParallelLogMap.has(data.parallel_run_id)) + iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) + else + iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) + setIterParallelLogMap(iterParallelLogMap) + if (iterations) + iterations.details = Array.from(iterParallelLogMap.values()) + } })) } else { @@ -309,6 +325,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterParallelLogMap, + setIterParallelLogMap, } = workflowStore.getState() const { getNodes, @@ -317,21 +335,21 @@ export const useWorkflowRun = () => { const nodes = getNodes() const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId if (nodeParentId) { - setWorkflowRunningData(produce(workflowRunningData!, (draft) => { - const tracing = draft.tracing! - const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + if (!data.execution_metadata.parallel_mode_run_id) { + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! 
+ const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node - if (iterations && iterations.details) { - const iterationIndex = data.execution_metadata?.iteration_index || 0 - if (!iterations.details[iterationIndex]) - iterations.details[iterationIndex] = [] + if (iterations && iterations.details) { + const iterationIndex = data.execution_metadata?.iteration_index || 0 + if (!iterations.details[iterationIndex]) + iterations.details[iterationIndex] = [] - const currIteration = iterations.details[iterationIndex] - const nodeIndex = currIteration.findIndex(node => - node.node_id === data.node_id && ( - node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), - ) - if (data.status === NodeRunningStatus.Succeeded) { + const currIteration = iterations.details[iterationIndex] + const nodeIndex = currIteration.findIndex(node => + node.node_id === data.node_id && ( + node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id), + ) if (nodeIndex !== -1) { currIteration[nodeIndex] = { ...currIteration[nodeIndex], @@ -344,8 +362,40 @@ export const useWorkflowRun = () => { } as any) } } - } - })) + })) + } + else { + // open parallel mode + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! 
+ const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node + + if (iterations && iterations.details) { + const iterRunID = data.execution_metadata?.parallel_mode_run_id + + const currIteration = iterParallelLogMap.get(iterRunID) + const nodeIndex = currIteration?.findIndex(node => + node.node_id === data.node_id && ( + node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id), + ) + if (currIteration) { + if (nodeIndex !== undefined && nodeIndex !== -1) { + currIteration[nodeIndex] = { + ...currIteration[nodeIndex], + ...data, + } as any + } + else { + currIteration.push({ + ...data, + } as any) + } + } + setIterParallelLogMap(iterParallelLogMap) + iterations.details = Array.from(iterParallelLogMap.values()) + } + })) + } } else { setWorkflowRunningData(produce(workflowRunningData!, (draft) => { @@ -379,6 +429,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -388,6 +439,7 @@ export const useWorkflowRun = () => { transform, } = store.getState() const nodes = getNodes() + setIterTimes(DEFAULT_ITER_TIMES) setWorkflowRunningData(produce(workflowRunningData!, (draft) => { draft.tracing!.push({ ...data, @@ -431,6 +483,8 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + iterTimes, + setIterTimes, } = workflowStore.getState() const { data } = params @@ -445,13 +499,14 @@ export const useWorkflowRun = () => { if (iteration.details!.length >= iteration.metadata.iterator_length!) return } - iteration?.details!.push([]) + if (!data.parallel_mode_run_id) + iteration?.details!.push([]) })) const nodes = getNodes() const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! - - currentNode.data._iterationIndex = data.index > 0 ? 
data.index : 1 + currentNode.data._iterationIndex = iterTimes + setIterTimes(iterTimes + 1) }) setNodes(newNodes) @@ -464,6 +519,7 @@ export const useWorkflowRun = () => { const { workflowRunningData, setWorkflowRunningData, + setIterTimes, } = workflowStore.getState() const { getNodes, @@ -480,7 +536,7 @@ export const useWorkflowRun = () => { }) } })) - + setIterTimes(DEFAULT_ITER_TIMES) const newNodes = produce(nodes, (draft) => { const currentNode = draft.find(node => node.id === data.node_id)! diff --git a/web/app/components/workflow/nodes/_base/components/field.tsx b/web/app/components/workflow/nodes/_base/components/field.tsx index 6459cf8056..b2f815a325 100644 --- a/web/app/components/workflow/nodes/_base/components/field.tsx +++ b/web/app/components/workflow/nodes/_base/components/field.tsx @@ -12,15 +12,15 @@ import Tooltip from '@/app/components/base/tooltip' type Props = { className?: string title: JSX.Element | string | DefaultTFuncReturn + tooltip?: React.ReactNode isSubTitle?: boolean - tooltip?: string supportFold?: boolean children?: JSX.Element | string | null operations?: JSX.Element inline?: boolean } -const Filed: FC = ({ +const Field: FC = ({ className, title, isSubTitle, @@ -60,4 +60,4 @@ const Filed: FC = ({ ) } -export default React.memo(Filed) +export default React.memo(Field) diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index bd5921c735..e864c419e2 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,6 +25,7 @@ import { useToolIcon, } from '../../hooks' import { useNodeIterationInteractions } from '../iteration/use-interactions' +import type { IterationNodeType } from '../iteration/types' import { NodeSourceHandle, NodeTargetHandle, @@ -34,6 +35,7 @@ import NodeControl from './components/node-control' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from 
'@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' +import Tooltip from '@/app/components/base/tooltip' type BaseNodeProps = { children: ReactElement @@ -166,9 +168,27 @@ const BaseNode: FC = ({ />
- {data.title} +
+ {data.title} +
+ { + data.type === BlockEnum.Iteration && (data as IterationNodeType).is_parallel && ( + +
+ {t('workflow.nodes.iteration.parallelModeEnableTitle')} +
+ {t('workflow.nodes.iteration.parallelModeEnableDesc')} +
} + > +
+ {t('workflow.nodes.iteration.parallelModeUpper')} +
+ + ) + } { data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && ( diff --git a/web/app/components/workflow/nodes/iteration/default.ts b/web/app/components/workflow/nodes/iteration/default.ts index 3afa52d06e..cdef268adb 100644 --- a/web/app/components/workflow/nodes/iteration/default.ts +++ b/web/app/components/workflow/nodes/iteration/default.ts @@ -1,7 +1,10 @@ -import { BlockEnum } from '../../types' +import { BlockEnum, ErrorHandleMode } from '../../types' import type { NodeDefault } from '../../types' import type { IterationNodeType } from './types' -import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' +import { + ALL_CHAT_AVAILABLE_BLOCKS, + ALL_COMPLETION_AVAILABLE_BLOCKS, +} from '@/app/components/workflow/constants' const i18nPrefix = 'workflow' const nodeDefault: NodeDefault = { @@ -10,25 +13,45 @@ const nodeDefault: NodeDefault = { iterator_selector: [], output_selector: [], _children: [], + _isShowTips: false, + is_parallel: false, + parallel_nums: 10, + error_handle_mode: ErrorHandleMode.Terminated, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS - : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End) + : ALL_COMPLETION_AVAILABLE_BLOCKS.filter( + type => type !== BlockEnum.End, + ) return nodes }, getAvailableNextNodes(isChatMode: boolean) { - const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS + const nodes = isChatMode + ? 
ALL_CHAT_AVAILABLE_BLOCKS + : ALL_COMPLETION_AVAILABLE_BLOCKS return nodes }, checkValid(payload: IterationNodeType, t: any) { let errorMessages = '' - if (!errorMessages && (!payload.iterator_selector || payload.iterator_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.input`) }) + if ( + !errorMessages + && (!payload.iterator_selector || payload.iterator_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.input`), + }) + } - if (!errorMessages && (!payload.output_selector || payload.output_selector.length === 0)) - errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.output`) }) + if ( + !errorMessages + && (!payload.output_selector || payload.output_selector.length === 0) + ) { + errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { + field: t(`${i18nPrefix}.nodes.iteration.output`), + }) + } return { isValid: !errorMessages, diff --git a/web/app/components/workflow/nodes/iteration/node.tsx b/web/app/components/workflow/nodes/iteration/node.tsx index 48a005a261..fda033b87a 100644 --- a/web/app/components/workflow/nodes/iteration/node.tsx +++ b/web/app/components/workflow/nodes/iteration/node.tsx @@ -8,12 +8,16 @@ import { useNodesInitialized, useViewport, } from 'reactflow' +import { useTranslation } from 'react-i18next' import { IterationStartNodeDumb } from '../iteration-start' import { useNodeIterationInteractions } from './use-interactions' import type { IterationNodeType } from './types' import AddBlock from './add-block' import cn from '@/utils/classnames' import type { NodeProps } from '@/app/components/workflow/types' +import Toast from '@/app/components/base/toast' + +const i18nPrefix = 'workflow.nodes.iteration' const Node: FC> = ({ id, @@ -22,11 +26,20 @@ const Node: FC> = ({ const { zoom } = useViewport() const nodesInitialized = 
useNodesInitialized() const { handleNodeIterationRerender } = useNodeIterationInteractions() + const { t } = useTranslation() useEffect(() => { if (nodesInitialized) handleNodeIterationRerender(id) - }, [nodesInitialized, id, handleNodeIterationRerender]) + if (data.is_parallel && data._isShowTips) { + Toast.notify({ + type: 'warning', + message: t(`${i18nPrefix}.answerNodeWarningDesc`), + duration: 5000, + }) + data._isShowTips = false + } + }, [nodesInitialized, id, handleNodeIterationRerender, data, t]) return (
> = ({ data, }) => { const { t } = useTranslation() - + const responseMethod = [ + { + value: ErrorHandleMode.Terminated, + name: t(`${i18nPrefix}.ErrorMethod.operationTerminated`), + }, + { + value: ErrorHandleMode.ContinueOnError, + name: t(`${i18nPrefix}.ErrorMethod.continueOnError`), + }, + { + value: ErrorHandleMode.RemoveAbnormalOutput, + name: t(`${i18nPrefix}.ErrorMethod.removeAbnormalOutput`), + }, + ] const { readOnly, inputs, @@ -47,6 +66,9 @@ const Panel: FC> = ({ setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } = useConfig(id, data) return ( @@ -87,6 +109,39 @@ const Panel: FC> = ({ />
+
+ {t(`${i18nPrefix}.parallelPanelDesc`)}
} inline> + + + + { + inputs.is_parallel && (
+ {t(`${i18nPrefix}.MaxParallelismDesc`)}
}> +
+ { changeParallelNums(Number(e.target.value)) }} /> + +
+ + + ) + } +
+ +
+ +
+ + + +
+ {isShowSingleRun && ( { @@ -184,6 +185,25 @@ const useConfig = (id: string, payload: IterationNodeType) => { }) }, [iteratorInputKey, runInputData, setRunInputData]) + const changeParallel = useCallback((value: boolean) => { + const newInputs = produce(inputs, (draft) => { + draft.is_parallel = value + }) + setInputs(newInputs) + }, [inputs, setInputs]) + + const changeErrorResponseMode = useCallback((item: Item) => { + const newInputs = produce(inputs, (draft) => { + draft.error_handle_mode = item.value as ErrorHandleMode + }) + setInputs(newInputs) + }, [inputs, setInputs]) + const changeParallelNums = useCallback((num: number) => { + const newInputs = produce(inputs, (draft) => { + draft.parallel_nums = num + }) + setInputs(newInputs) + }, [inputs, setInputs]) return { readOnly, inputs, @@ -210,6 +230,9 @@ const useConfig = (id: string, payload: IterationNodeType) => { setIterator, iteratorInputKey, iterationRunResult, + changeParallel, + changeErrorResponseMode, + changeParallelNums, } } diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 58a4561e2c..5d932a1ba2 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -9,6 +9,8 @@ import { produce, setAutoFreeze } from 'immer' import { uniqBy } from 'lodash-es' import { useWorkflowRun } from '../../hooks' import { NodeRunningStatus, WorkflowRunningStatus } from '../../types' +import { useWorkflowStore } from '../../store' +import { DEFAULT_ITER_TIMES } from '../../constants' import type { ChatItem, Inputs, @@ -43,6 +45,7 @@ export const useChat = ( const { notify } = useToastContext() const { handleRun } = useWorkflowRun() const hasStopResponded = useRef(false) + const workflowStore = useWorkflowStore() const conversationId = useRef('') const taskIdRef = useRef('') const [chatList, setChatList] = useState(prevChatList || []) @@ -52,6 
+55,9 @@ export const useChat = ( const [suggestedQuestions, setSuggestQuestions] = useState([]) const suggestedQuestionsAbortControllerRef = useRef(null) + const { + setIterTimes, + } = workflowStore.getState() useEffect(() => { setAutoFreeze(false) return () => { @@ -102,15 +108,16 @@ export const useChat = ( handleResponding(false) if (stopChat && taskIdRef.current) stopChat(taskIdRef.current) - + setIterTimes(DEFAULT_ITER_TIMES) if (suggestedQuestionsAbortControllerRef.current) suggestedQuestionsAbortControllerRef.current.abort() - }, [handleResponding, stopChat]) + }, [handleResponding, setIterTimes, stopChat]) const handleRestart = useCallback(() => { conversationId.current = '' taskIdRef.current = '' handleStop() + setIterTimes(DEFAULT_ITER_TIMES) const newChatList = config?.opening_statement ? [{ id: `${Date.now()}`, @@ -126,6 +133,7 @@ export const useChat = ( config, handleStop, handleUpdateChatList, + setIterTimes, ]) const updateCurrentQA = useCallback(({ diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 9e636e902b..89db43fa35 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -60,36 +60,67 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe }, [notify, getResultCallback]) const formatNodeList = useCallback((list: NodeTracing[]) => { - const allItems = list.reverse() + const allItems = [...list].reverse() const result: NodeTracing[] = [] - allItems.forEach((item) => { - const { node_type, execution_metadata } = item - if (node_type !== BlockEnum.Iteration) { - const isInIteration = !!execution_metadata?.iteration_id + const groupMap = new Map() - if (isInIteration) { - const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id) - const iterationDetails = iterationNode?.details - const currentIterationIndex = execution_metadata?.iteration_index ?? 
0 - - if (Array.isArray(iterationDetails)) { - if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex]) - iterationDetails[currentIterationIndex] = [item] - else - iterationDetails[currentIterationIndex].push(item) - } - return - } - // not in iteration - result.push(item) - - return - } + const processIterationNode = (item: NodeTracing) => { result.push({ ...item, details: [], }) + } + const updateParallelModeGroup = (runId: string, item: NodeTracing, iterationNode: NodeTracing) => { + if (!groupMap.has(runId)) + groupMap.set(runId, [item]) + else + groupMap.get(runId)!.push(item) + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + + iterationNode.details = Array.from(groupMap.values()) + } + const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => { + const { details } = iterationNode + if (details) { + if (!details[index]) + details[index] = [item] + else + details[index].push(item) + } + + if (item.status === 'failed') { + iterationNode.status = 'failed' + iterationNode.error = item.error + } + } + const processNonIterationNode = (item: NodeTracing) => { + const { execution_metadata } = item + if (!execution_metadata?.iteration_id) { + result.push(item) + return + } + + const iterationNode = result.find(node => node.node_id === execution_metadata.iteration_id) + if (!iterationNode || !Array.isArray(iterationNode.details)) + return + + const { parallel_mode_run_id, iteration_index = 0 } = execution_metadata + + if (parallel_mode_run_id) + updateParallelModeGroup(parallel_mode_run_id, item, iterationNode) + else + updateSequentialModeGroup(iteration_index, item, iterationNode) + } + + allItems.forEach((item) => { + item.node_type === BlockEnum.Iteration + ? 
processIterationNode(item) + : processNonIterationNode(item) }) + return result }, []) diff --git a/web/app/components/workflow/run/iteration-result-panel.tsx b/web/app/components/workflow/run/iteration-result-panel.tsx index 7e2f6cbc00..c4cd909f2e 100644 --- a/web/app/components/workflow/run/iteration-result-panel.tsx +++ b/web/app/components/workflow/run/iteration-result-panel.tsx @@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next' import { RiArrowRightSLine, RiCloseLine, + RiErrorWarningLine, } from '@remixicon/react' import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows' import TracingPanel from './tracing-panel' @@ -27,7 +28,7 @@ const IterationResultPanel: FC = ({ noWrap, }) => { const { t } = useTranslation() - const [expandedIterations, setExpandedIterations] = useState>([]) + const [expandedIterations, setExpandedIterations] = useState>({}) const toggleIteration = useCallback((index: number) => { setExpandedIterations(prev => ({ @@ -71,10 +72,19 @@ const IterationResultPanel: FC = ({ {t(`${i18nPrefix}.iteration`)} {index + 1} - + { + iteration.some(item => item.status === 'failed') + ? ( + + ) + : (< RiArrowRightSLine className={ + cn( + 'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0', + expandedIterations[index] && 'transform rotate-90', + )} /> + ) + } + {expandedIterations[index] &&
= ({ return iteration_length } + const getErrorCount = (details: NodeTracing[][] | undefined) => { + if (!details || details.length === 0) + return 0 + return details.reduce((acc, iteration) => { + if (iteration.some(item => item.status === 'failed')) + acc++ + return acc + }, 0) + } useEffect(() => { setCollapseState(!nodeInfo.expand) }, [nodeInfo.expand, setCollapseState]) @@ -136,7 +145,12 @@ const NodePanel: FC = ({ onClick={handleOnShowIterationDetail} > -
{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}
+
{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}{getErrorCount(nodeInfo.details) > 0 && ( + <> + {t('workflow.nodes.iteration.comma')} + {t('workflow.nodes.iteration.error', { count: getErrorCount(nodeInfo.details) })} + + )}
{justShowIterationNavArrow ? ( diff --git a/web/app/components/workflow/store.ts b/web/app/components/workflow/store.ts index c2a6823e6b..c4a625c777 100644 --- a/web/app/components/workflow/store.ts +++ b/web/app/components/workflow/store.ts @@ -21,6 +21,7 @@ import type { WorkflowRunningData, } from './types' import { WorkflowContext } from './context' +import type { NodeTracing } from '@/types/workflow' // #TODO chatVar# // const MOCK_DATA = [ @@ -166,6 +167,10 @@ type Shape = { setShowImportDSLModal: (showImportDSLModal: boolean) => void showTips: string setShowTips: (showTips: string) => void + iterTimes: number + setIterTimes: (iterTimes: number) => void + iterParallelLogMap: Map + setIterParallelLogMap: (iterParallelLogMap: Map) => void } export const createWorkflowStore = () => { @@ -281,6 +286,11 @@ export const createWorkflowStore = () => { setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })), showTips: '', setShowTips: showTips => set(() => ({ showTips })), + iterTimes: 1, + setIterTimes: iterTimes => set(() => ({ iterTimes })), + iterParallelLogMap: new Map(), + setIterParallelLogMap: iterParallelLogMap => set(() => ({ iterParallelLogMap })), + })) } diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index 81bec41eac..9b6ad033bf 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -36,7 +36,11 @@ export enum ControlMode { Pointer = 'pointer', Hand = 'hand', } - +export enum ErrorHandleMode { + Terminated = 'terminated', + ContinueOnError = 'continue-on-error', + RemoveAbnormalOutput = 'remove-abnormal-output', +} export type Branch = { id: string name: string diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index 91656e3bbc..aaf333f4d7 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -19,7 +19,7 @@ import type { ToolWithProvider, ValueSelector, } from './types' 
-import { BlockEnum } from './types' +import { BlockEnum, ErrorHandleMode } from './types' import { CUSTOM_NODE, ITERATION_CHILDREN_Z_INDEX, @@ -267,8 +267,13 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { }) } - if (node.data.type === BlockEnum.Iteration) - node.data._children = iterationNodeMap[node.id] || [] + if (node.data.type === BlockEnum.Iteration) { + const iterationNodeData = node.data as IterationNodeType + iterationNodeData._children = iterationNodeMap[node.id] || [] + iterationNodeData.is_parallel = iterationNodeData.is_parallel || false + iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10 + iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated + } return node }) diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index ea8355500a..1c6639aba0 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -556,6 +556,23 @@ const translation = { iteration_one: '{{count}} Iteration', iteration_other: '{{count}} Iterations', currentIteration: 'Current Iteration', + comma: ', ', + error_one: '{{count}} Error', + error_other: '{{count}} Errors', + parallelMode: 'Parallel Mode', + parallelModeUpper: 'PARALLEL MODE', + parallelModeEnableTitle: 'Parallel Mode Enabled', + parallelModeEnableDesc: 'In parallel mode, tasks within iterations support parallel execution. 
You can configure this in the properties panel on the right.', + parallelPanelDesc: 'In parallel mode, tasks in the iteration support parallel execution.', + MaxParallelismTitle: 'Maximum parallelism', + MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.', + errorResponseMethod: 'Error response method', + ErrorMethod: { + operationTerminated: 'terminated', + continueOnError: 'continue-on-error', + removeAbnormalOutput: 'remove-abnormal-output', + }, + answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.', }, note: { addNote: 'Add Note', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 515d0fe235..1229ba8c03 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -556,6 +556,23 @@ const translation = { iteration_one: '{{count}}个迭代', iteration_other: '{{count}}个迭代', currentIteration: '当前迭代', + comma: ',', + error_one: '{{count}}个失败', + error_other: '{{count}}个失败', + parallelMode: '并行模式', + parallelModeUpper: '并行模式', + parallelModeEnableTitle: '并行模式启用', + parallelModeEnableDesc: '启用并行模式时迭代内的任务支持并行执行。你可以在右侧的属性面板中进行配置。', + parallelPanelDesc: '在并行模式下,迭代中的任务支持并行执行。', + MaxParallelismTitle: '最大并行度', + MaxParallelismDesc: '最大并行度用于控制单次迭代中同时执行的任务数量。', + errorResponseMethod: '错误响应方法', + ErrorMethod: { + operationTerminated: '错误时终止', + continueOnError: '忽略错误并继续', + removeAbnormalOutput: '移除错误输出', + }, + answerNodeWarningDesc: '并行模式警告:在迭代中,回答节点、会话变量赋值和工具持久读/写操作可能会导致异常。', }, note: { addNote: '添加注释', diff --git a/web/types/workflow.ts b/web/types/workflow.ts index 810026b084..3c0675b605 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -19,6 +19,7 @@ export type NodeTracing = { process_data: any outputs?: any status: string + parallel_run_id?: string error?: string elapsed_time: number execution_metadata: { @@ -31,6 
+32,7 @@ export type NodeTracing = { parallel_start_node_id?: string parent_parallel_id?: string parent_parallel_start_node_id?: string + parallel_mode_run_id?: string } metadata: { iterator_length: number @@ -121,6 +123,7 @@ export type NodeStartedResponse = { id: string node_id: string iteration_id?: string + parallel_run_id?: string node_type: string index: number predecessor_node_id?: string @@ -166,6 +169,7 @@ export type NodeFinishedResponse = { parallel_start_node_id?: string iteration_index?: number iteration_id?: string + parallel_mode_run_id: string } created_at: number files?: FileResponse[] @@ -200,6 +204,7 @@ export type IterationNextResponse = { output: any extras?: any created_at: number + parallel_mode_run_id: string execution_metadata: { parallel_id?: string } From acb22f0fde2e4fc780d01a626b068de0b9dd2e72 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 5 Nov 2024 10:34:28 +0800 Subject: [PATCH 18/29] =?UTF-8?q?Updates:=20Add=20mplfonts=20library=20for?= =?UTF-8?q?=20customizing=20matplotlib=20fonts=20and=20Va=E2=80=A6=20(#990?= =?UTF-8?q?3)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/poetry.lock | 70 ++++++++++++++++++++++++++++++++++++++++++---- api/pyproject.toml | 3 +- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/api/poetry.lock b/api/poetry.lock index f543b2b4b9..2a93fa38f9 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -2532,6 +2532,19 @@ files = [ {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, ] +[[package]] +name = "fire" +version = "0.7.0" +description = "A library for automatically generating command line interfaces." 
+optional = false +python-versions = "*" +files = [ + {file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"}, +] + +[package.dependencies] +termcolor = "*" + [[package]] name = "flasgger" version = "0.9.7.1" @@ -2697,6 +2710,19 @@ files = [ {file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"}, ] +[[package]] +name = "fontmeta" +version = "1.6.1" +description = "An Utility to get ttf/otf font metadata" +optional = false +python-versions = "*" +files = [ + {file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"}, +] + +[package.dependencies] +fonttools = "*" + [[package]] name = "fonttools" version = "4.54.1" @@ -5279,6 +5305,22 @@ files = [ {file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"}, ] +[[package]] +name = "mplfonts" +version = "0.0.8" +description = "Fonts manager for matplotlib" +optional = false +python-versions = ">=3.7" +files = [ + {file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"}, + {file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"}, +] + +[package.dependencies] +fire = ">=0.4.0" +fontmeta = ">=1.6.1" +matplotlib = ">=3.4" + [[package]] name = "mpmath" version = "1.3.0" @@ -9300,6 +9342,20 @@ files = [ [package.dependencies] tencentcloud-sdk-python-common = "3.0.1257" +[[package]] +name = "termcolor" +version = "2.5.0" +description = "ANSI color formatting for output in terminal" +optional = false +python-versions = ">=3.9" +files = [ + {file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"}, + {file = "termcolor-2.5.0.tar.gz", hash = 
"sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"}, +] + +[package.extras] +tests = ["pytest", "pytest-cov"] + [[package]] name = "threadpoolctl" version = "3.5.0" @@ -10046,13 +10102,13 @@ files = [ [[package]] name = "vanna" -version = "0.7.3" +version = "0.7.5" description = "Generate SQL queries from natural language" optional = false python-versions = ">=3.9" files = [ - {file = "vanna-0.7.3-py3-none-any.whl", hash = "sha256:82ba39e5d6c503d1c8cca60835ed401d20ec3a3da98d487f529901dcb30061d6"}, - {file = "vanna-0.7.3.tar.gz", hash = "sha256:4590dd94d2fe180b4efc7a83c867b73144ef58794018910dc226857cfb703077"}, + {file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"}, + {file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"}, ] [package.dependencies] @@ -10073,7 +10129,7 @@ sqlparse = "*" tabulate = "*" [package.extras] -all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "zhipuai"] +all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", 
"xinference-client", "zhipuai"] anthropic = ["anthropic"] azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"] bedrock = ["boto3", "botocore"] @@ -10081,6 +10137,8 @@ bigquery = ["google-cloud-bigquery"] chromadb = ["chromadb"] clickhouse = ["clickhouse_connect"] duckdb = ["duckdb"] +faiss-cpu = ["faiss-cpu"] +faiss-gpu = ["faiss-gpu"] gemini = ["google-generativeai"] google = ["google-cloud-aiplatform", "google-generativeai"] hf = ["transformers"] @@ -10091,6 +10149,7 @@ mysql = ["PyMySQL"] ollama = ["httpx", "ollama"] openai = ["openai"] opensearch = ["opensearch-dsl", "opensearch-py"] +pgvector = ["langchain-postgres (>=0.0.12)"] pinecone = ["fastembed", "pinecone-client"] postgres = ["db-dtypes", "psycopg2-binary"] qdrant = ["fastembed", "qdrant-client"] @@ -10099,6 +10158,7 @@ snowflake = ["snowflake-connector-python"] test = ["tox"] vllm = ["vllm"] weaviate = ["weaviate-client"] +xinference-client = ["xinference-client"] zhipuai = ["zhipuai"] [[package]] @@ -10940,4 +11000,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b" +content-hash = "e4794898403da4ad7b51f248a6c07632a949114c1b569406d3aa6a94c62510a5" diff --git a/api/pyproject.toml b/api/pyproject.toml index ee7cf4d618..a79e1641d0 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -206,13 +206,14 @@ cloudscraper = "1.2.71" duckduckgo-search = "~6.3.0" jsonpath-ng = "1.6.1" matplotlib = "~3.8.2" +mplfonts = "~0.0.8" newspaper3k = "0.2.8" nltk = "3.9.1" numexpr = "~2.9.0" pydub = "~0.25.1" qrcode = "~7.4.2" twilio = "~9.0.4" -vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } +vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] } wikipedia = "1.4.0" yfinance = "~0.2.40" From de5dfd99f65151402140f6e0afc16a13154cbe89 Mon Sep 17 00:00:00 2001 From: 
"github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Tue, 5 Nov 2024 10:57:32 +0800 Subject: [PATCH 19/29] chore: translate i18n files (#10273) Co-authored-by: laipz8200 <16485841+laipz8200@users.noreply.github.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> --- web/i18n/de-DE/workflow.ts | 17 +++++++++++++++++ web/i18n/es-ES/workflow.ts | 17 +++++++++++++++++ web/i18n/fa-IR/workflow.ts | 17 +++++++++++++++++ web/i18n/fr-FR/workflow.ts | 17 +++++++++++++++++ web/i18n/hi-IN/workflow.ts | 17 +++++++++++++++++ web/i18n/it-IT/workflow.ts | 17 +++++++++++++++++ web/i18n/ja-JP/workflow.ts | 17 +++++++++++++++++ web/i18n/ko-KR/workflow.ts | 17 +++++++++++++++++ web/i18n/pl-PL/workflow.ts | 17 +++++++++++++++++ web/i18n/pt-BR/workflow.ts | 17 +++++++++++++++++ web/i18n/ro-RO/workflow.ts | 17 +++++++++++++++++ web/i18n/ru-RU/workflow.ts | 17 +++++++++++++++++ web/i18n/tr-TR/workflow.ts | 17 +++++++++++++++++ web/i18n/uk-UA/workflow.ts | 17 +++++++++++++++++ web/i18n/vi-VN/workflow.ts | 17 +++++++++++++++++ web/i18n/zh-Hant/workflow.ts | 17 +++++++++++++++++ 16 files changed, 272 insertions(+) diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index bde0250fcc..d05070c308 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteration', iteration_other: '{{count}} Iterationen', currentIteration: 'Aktuelle Iteration', + ErrorMethod: { + operationTerminated: 'beendet', + removeAbnormalOutput: 'remove-abnormale_ausgabe', + continueOnError: 'Fehler "Fortfahren bei"', + }, + MaxParallelismTitle: 'Maximale Parallelität', + parallelMode: 'Paralleler Modus', + errorResponseMethod: 'Methode der Fehlerantwort', + error_one: '{{Anzahl}} Fehler', + error_other: '{{Anzahl}} Irrtümer', + MaxParallelismDesc: 'Die maximale Parallelität wird verwendet, um die Anzahl der Aufgaben zu steuern, die gleichzeitig in einer einzigen 
Iteration ausgeführt werden.', + parallelPanelDesc: 'Im parallelen Modus unterstützen Aufgaben in der Iteration die parallele Ausführung.', + parallelModeEnableDesc: 'Im parallelen Modus unterstützen Aufgaben innerhalb von Iterationen die parallele Ausführung. Sie können dies im Eigenschaftenbereich auf der rechten Seite konfigurieren.', + answerNodeWarningDesc: 'Warnung im parallelen Modus: Antwortknoten, Zuweisungen von Konversationsvariablen und persistente Lese-/Schreibvorgänge innerhalb von Iterationen können Ausnahmen verursachen.', + parallelModeEnableTitle: 'Paralleler Modus aktiviert', + parallelModeUpper: 'PARALLELER MODUS', + comma: ',', }, note: { editor: { diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index 59a330e7f4..6c9af49c4d 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteración', iteration_other: '{{count}} Iteraciones', currentIteration: 'Iteración actual', + ErrorMethod: { + operationTerminated: 'Terminado', + continueOnError: 'Continuar en el error', + removeAbnormalOutput: 'eliminar-salida-anormal', + }, + comma: ',', + errorResponseMethod: 'Método de respuesta a errores', + error_one: '{{conteo}} Error', + parallelPanelDesc: 'En el modo paralelo, las tareas de la iteración admiten la ejecución en paralelo.', + MaxParallelismTitle: 'Máximo paralelismo', + error_other: '{{conteo}} Errores', + parallelMode: 'Modo paralelo', + parallelModeEnableDesc: 'En el modo paralelo, las tareas dentro de las iteraciones admiten la ejecución en paralelo. 
Puede configurar esto en el panel de propiedades a la derecha.', + parallelModeUpper: 'MODO PARALELO', + MaxParallelismDesc: 'El paralelismo máximo se utiliza para controlar el número de tareas ejecutadas simultáneamente en una sola iteración.', + answerNodeWarningDesc: 'Advertencia de modo paralelo: Los nodos de respuesta, las asignaciones de variables de conversación y las operaciones de lectura/escritura persistentes dentro de las iteraciones pueden provocar excepciones.', + parallelModeEnableTitle: 'Modo paralelo habilitado', }, note: { addNote: 'Agregar nota', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index b1f9384159..4b00390663 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} تکرار', iteration_other: '{{count}} تکرارها', currentIteration: 'تکرار فعلی', + ErrorMethod: { + continueOnError: 'ادامه در خطا', + operationTerminated: 'فسخ', + removeAbnormalOutput: 'حذف خروجی غیرطبیعی', + }, + error_one: '{{تعداد}} خطا', + error_other: '{{تعداد}} خطاهای', + parallelMode: 'حالت موازی', + errorResponseMethod: 'روش پاسخ به خطا', + parallelModeEnableTitle: 'حالت موازی فعال است', + parallelModeUpper: 'حالت موازی', + comma: ',', + parallelModeEnableDesc: 'در حالت موازی، وظایف درون تکرارها از اجرای موازی پشتیبانی می کنند. 
می توانید این را در پانل ویژگی ها در سمت راست پیکربندی کنید.', + MaxParallelismTitle: 'حداکثر موازی سازی', + parallelPanelDesc: 'در حالت موازی، وظایف در تکرار از اجرای موازی پشتیبانی می کنند.', + MaxParallelismDesc: 'حداکثر موازی سازی برای کنترل تعداد وظایف اجرا شده به طور همزمان در یک تکرار واحد استفاده می شود.', + answerNodeWarningDesc: 'هشدار حالت موازی: گره های پاسخ، تکالیف متغیر مکالمه و عملیات خواندن/نوشتن مداوم در تکرارها ممکن است باعث استثنائات شود.', }, note: { addNote: 'افزودن یادداشت', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index e56932455f..e736e2cb07 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Itération', iteration_other: '{{count}} Itérations', currentIteration: 'Itération actuelle', + ErrorMethod: { + operationTerminated: 'Terminé', + removeAbnormalOutput: 'remove-abnormal-output', + continueOnError: 'continuer sur l’erreur', + }, + comma: ',', + error_one: '{{compte}} Erreur', + error_other: '{{compte}} Erreurs', + parallelModeEnableDesc: 'En mode parallèle, les tâches au sein des itérations prennent en charge l’exécution parallèle. 
Vous pouvez le configurer dans le panneau des propriétés à droite.', + parallelModeUpper: 'MODE PARALLÈLE', + parallelPanelDesc: 'En mode parallèle, les tâches de l’itération prennent en charge l’exécution parallèle.', + MaxParallelismDesc: 'Le parallélisme maximal est utilisé pour contrôler le nombre de tâches exécutées simultanément en une seule itération.', + errorResponseMethod: 'Méthode de réponse aux erreurs', + MaxParallelismTitle: 'Parallélisme maximal', + answerNodeWarningDesc: 'Avertissement en mode parallèle : les nœuds de réponse, les affectations de variables de conversation et les opérations de lecture/écriture persistantes au sein des itérations peuvent provoquer des exceptions.', + parallelModeEnableTitle: 'Mode parallèle activé', + parallelMode: 'Mode parallèle', }, note: { addNote: 'Ajouter note', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 1473f78ccd..4112643488 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -577,6 +577,23 @@ const translation = { iteration_one: '{{count}} इटरेशन', iteration_other: '{{count}} इटरेशन्स', currentIteration: 'वर्तमान इटरेशन', + ErrorMethod: { + operationTerminated: 'समाप्त', + continueOnError: 'जारी रखें-पर-त्रुटि', + removeAbnormalOutput: 'निकालें-असामान्य-आउटपुट', + }, + comma: ',', + error_other: '{{गिनती}} त्रुटियों', + error_one: '{{गिनती}} चूक', + parallelMode: 'समानांतर मोड', + parallelModeUpper: 'समानांतर मोड', + errorResponseMethod: 'त्रुटि प्रतिक्रिया विधि', + MaxParallelismTitle: 'अधिकतम समांतरता', + parallelModeEnableTitle: 'समानांतर मोड सक्षम किया गया', + parallelModeEnableDesc: 'समानांतर मोड में, पुनरावृत्तियों के भीतर कार्य समानांतर निष्पादन का समर्थन करते हैं। आप इसे दाईं ओर गुण पैनल में कॉन्फ़िगर कर सकते हैं।', + parallelPanelDesc: 'समानांतर मोड में, पुनरावृत्ति में कार्य समानांतर निष्पादन का समर्थन करते हैं।', + MaxParallelismDesc: 'अधिकतम समांतरता का उपयोग एकल पुनरावृत्ति में एक साथ निष्पादित कार्यों की संख्या को नियंत्रित करने के लिए किया 
जाता है।', + answerNodeWarningDesc: 'समानांतर मोड चेतावनी: उत्तर नोड्स, वार्तालाप चर असाइनमेंट, और पुनरावृत्तियों के भीतर लगातार पढ़ने/लिखने की कार्रवाई अपवाद पैदा कर सकती है।', }, note: { addNote: 'नोट जोड़ें', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index 19fa7bfbb5..756fb665af 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -584,6 +584,23 @@ const translation = { iteration_one: '{{count}} Iterazione', iteration_other: '{{count}} Iterazioni', currentIteration: 'Iterazione Corrente', + ErrorMethod: { + operationTerminated: 'Terminato', + continueOnError: 'continua sull\'errore', + removeAbnormalOutput: 'rimuovi-output-anomalo', + }, + error_one: '{{conteggio}} Errore', + parallelMode: 'Modalità parallela', + MaxParallelismTitle: 'Parallelismo massimo', + error_other: '{{conteggio}} Errori', + parallelModeEnableDesc: 'In modalità parallela, le attività all\'interno delle iterazioni supportano l\'esecuzione parallela. È possibile configurare questa opzione nel pannello delle proprietà a destra.', + MaxParallelismDesc: 'Il parallelismo massimo viene utilizzato per controllare il numero di attività eseguite contemporaneamente in una singola iterazione.', + errorResponseMethod: 'Metodo di risposta all\'errore', + parallelModeEnableTitle: 'Modalità parallela abilitata', + parallelModeUpper: 'MODALITÀ PARALLELA', + comma: ',', + parallelPanelDesc: 'In modalità parallela, le attività nell\'iterazione supportano l\'esecuzione parallela.', + answerNodeWarningDesc: 'Avviso in modalità parallela: i nodi di risposta, le assegnazioni di variabili di conversazione e le operazioni di lettura/scrittura persistenti all\'interno delle iterazioni possono causare eccezioni.', }, note: { addNote: 'Aggiungi Nota', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index b6c7786081..a82ba71e48 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -558,6 +558,23 @@ const translation = { 
iteration_one: '{{count}} イテレーション', iteration_other: '{{count}} イテレーション', currentIteration: '現在のイテレーション', + ErrorMethod: { + operationTerminated: '終了', + continueOnError: 'エラー時に続行', + removeAbnormalOutput: 'アブノーマルアウトプットの削除', + }, + comma: ',', + error_other: '{{カウント}}エラー', + error_one: '{{カウント}}エラー', + parallelModeUpper: 'パラレルモード', + parallelMode: 'パラレルモード', + MaxParallelismTitle: '最大並列処理', + errorResponseMethod: 'エラー応答方式', + parallelPanelDesc: '並列モードでは、イテレーションのタスクは並列実行をサポートします。', + parallelModeEnableDesc: '並列モードでは、イテレーション内のタスクは並列実行をサポートします。これは、右側のプロパティパネルで構成できます。', + parallelModeEnableTitle: 'パラレルモード有効', + MaxParallelismDesc: '最大並列処理は、1 回の反復で同時に実行されるタスクの数を制御するために使用されます。', + answerNodeWarningDesc: '並列モードの警告: 応答ノード、会話変数の割り当て、およびイテレーション内の永続的な読み取り/書き込み操作により、例外が発生する可能性があります。', }, note: { addNote: 'コメントを追加', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index b62aff2068..589831401c 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} 반복', iteration_other: '{{count}} 반복', currentIteration: '현재 반복', + ErrorMethod: { + operationTerminated: '종료', + continueOnError: '오류 발생 시 계속', + removeAbnormalOutput: '비정상 출력 제거', + }, + comma: ',', + error_one: '{{개수}} 오류', + parallelMode: '병렬 모드', + errorResponseMethod: '오류 응답 방법', + parallelModeUpper: '병렬 모드', + MaxParallelismTitle: '최대 병렬 처리', + error_other: '{{개수}} 오류', + parallelModeEnableTitle: 'Parallel Mode Enabled(병렬 모드 사용)', + parallelPanelDesc: '병렬 모드에서 반복의 작업은 병렬 실행을 지원합니다.', + parallelModeEnableDesc: '병렬 모드에서는 반복 내의 작업이 병렬 실행을 지원합니다. 
오른쪽의 속성 패널에서 이를 구성할 수 있습니다.', + MaxParallelismDesc: '최대 병렬 처리는 단일 반복에서 동시에 실행되는 작업 수를 제어하는 데 사용됩니다.', + answerNodeWarningDesc: '병렬 모드 경고: 응답 노드, 대화 변수 할당 및 반복 내의 지속적인 읽기/쓰기 작업으로 인해 예외가 발생할 수 있습니다.', }, note: { editor: { diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index aace1b2642..f118f7945c 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteracja', iteration_other: '{{count}} Iteracje', currentIteration: 'Bieżąca iteracja', + ErrorMethod: { + continueOnError: 'kontynuacja w przypadku błędu', + operationTerminated: 'Zakończone', + removeAbnormalOutput: 'usuń-nieprawidłowe-wyjście', + }, + comma: ',', + parallelModeUpper: 'TRYB RÓWNOLEGŁY', + parallelModeEnableTitle: 'Włączony tryb równoległy', + MaxParallelismTitle: 'Maksymalna równoległość', + error_one: '{{liczba}} Błąd', + error_other: '{{liczba}} Błędy', + parallelPanelDesc: 'W trybie równoległym zadania w iteracji obsługują wykonywanie równoległe.', + parallelMode: 'Tryb równoległy', + MaxParallelismDesc: 'Maksymalna równoległość służy do kontrolowania liczby zadań wykonywanych jednocześnie w jednej iteracji.', + parallelModeEnableDesc: 'W trybie równoległym zadania w iteracjach obsługują wykonywanie równoległe. 
Możesz to skonfigurować w panelu właściwości po prawej stronie.', + answerNodeWarningDesc: 'Ostrzeżenie w trybie równoległym: węzły odpowiedzi, przypisania zmiennych konwersacji i trwałe operacje odczytu/zapisu w iteracjach mogą powodować wyjątki.', + errorResponseMethod: 'Metoda odpowiedzi na błąd', }, note: { editor: { diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index f0f2fec0e2..44afda5cd4 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iteração', iteration_other: '{{count}} Iterações', currentIteration: 'Iteração atual', + ErrorMethod: { + continueOnError: 'continuar em erro', + removeAbnormalOutput: 'saída anormal de remoção', + operationTerminated: 'Terminada', + }, + MaxParallelismTitle: 'Paralelismo máximo', + parallelModeEnableTitle: 'Modo paralelo ativado', + errorResponseMethod: 'Método de resposta de erro', + error_other: '{{contagem}} Erros', + parallelMode: 'Modo paralelo', + parallelModeUpper: 'MODO PARALELO', + error_one: '{{contagem}} Erro', + parallelModeEnableDesc: 'No modo paralelo, as tarefas dentro das iterações dão suporte à execução paralela. 
Você pode configurar isso no painel de propriedades à direita.', + comma: ',', + MaxParallelismDesc: 'O paralelismo máximo é usado para controlar o número de tarefas executadas simultaneamente em uma única iteração.', + answerNodeWarningDesc: 'Aviso de modo paralelo: nós de resposta, atribuições de variáveis de conversação e operações persistentes de leitura/gravação em iterações podem causar exceções.', + parallelPanelDesc: 'No modo paralelo, as tarefas na iteração dão suporte à execução paralela.', }, note: { editor: { diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index ab0100d347..d8cd84f730 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Iterație', iteration_other: '{{count}} Iterații', currentIteration: 'Iterație curentă', + ErrorMethod: { + operationTerminated: 'Încheiată', + continueOnError: 'continuare-la-eroare', + removeAbnormalOutput: 'elimină-ieșire-anormală', + }, + parallelModeEnableTitle: 'Modul paralel activat', + errorResponseMethod: 'Metoda de răspuns la eroare', + comma: ',', + parallelModeEnableDesc: 'În modul paralel, sarcinile din iterații acceptă execuția paralelă. 
Puteți configura acest lucru în panoul de proprietăți din dreapta.', + parallelModeUpper: 'MOD PARALEL', + MaxParallelismTitle: 'Paralelism maxim', + parallelMode: 'Mod paralel', + error_other: '{{număr}} Erori', + error_one: '{{număr}} Eroare', + parallelPanelDesc: 'În modul paralel, activitățile din iterație acceptă execuția paralelă.', + MaxParallelismDesc: 'Paralelismul maxim este utilizat pentru a controla numărul de sarcini executate simultan într-o singură iterație.', + answerNodeWarningDesc: 'Avertisment modul paralel: Nodurile de răspuns, atribuirea variabilelor de conversație și operațiunile persistente de citire/scriere în iterații pot cauza excepții.', }, note: { editor: { diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index 27735fbb7d..c822f8c3e5 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Итерация', iteration_other: '{{count}} Итераций', currentIteration: 'Текущая итерация', + ErrorMethod: { + operationTerminated: 'Прекращено', + continueOnError: 'продолжить по ошибке', + removeAbnormalOutput: 'удалить аномальный вывод', + }, + comma: ',', + error_other: '{{Количество}} Ошибки', + errorResponseMethod: 'Метод реагирования на ошибку', + MaxParallelismTitle: 'Максимальный параллелизм', + parallelModeUpper: 'ПАРАЛЛЕЛЬНЫЙ РЕЖИМ', + error_one: '{{Количество}} Ошибка', + parallelModeEnableTitle: 'Параллельный режим включен', + parallelMode: 'Параллельный режим', + parallelPanelDesc: 'В параллельном режиме задачи в итерации поддерживают параллельное выполнение.', + parallelModeEnableDesc: 'В параллельном режиме задачи в итерациях поддерживают параллельное выполнение. 
Вы можете настроить это на панели свойств справа.', + MaxParallelismDesc: 'Максимальный параллелизм используется для управления количеством задач, выполняемых одновременно в одной итерации.', + answerNodeWarningDesc: 'Предупреждение о параллельном режиме: узлы ответов, присвоение переменных диалога и постоянные операции чтения и записи в итерациях могут вызывать исключения.', }, note: { addNote: 'Добавить заметку', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 82718ebc03..e6e25f6d0e 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -558,6 +558,23 @@ const translation = { iteration_one: '{{count}} Yineleme', iteration_other: '{{count}} Yineleme', currentIteration: 'Mevcut Yineleme', + ErrorMethod: { + operationTerminated: 'Sonlandırıldı', + continueOnError: 'Hata Üzerine Devam Et', + removeAbnormalOutput: 'anormal çıktıyı kaldır', + }, + parallelModeUpper: 'PARALEL MOD', + parallelMode: 'Paralel Mod', + MaxParallelismTitle: 'Maksimum paralellik', + error_one: '{{sayı}} Hata', + errorResponseMethod: 'Hata yanıtı yöntemi', + comma: ',', + parallelModeEnableTitle: 'Paralel Mod Etkin', + error_other: '{{sayı}} Hata', + parallelPanelDesc: 'Paralel modda, yinelemedeki görevler paralel yürütmeyi destekler.', + answerNodeWarningDesc: 'Paralel mod uyarısı: Yinelemeler içindeki yanıt düğümleri, konuşma değişkeni atamaları ve kalıcı okuma/yazma işlemleri özel durumlara neden olabilir.', + parallelModeEnableDesc: 'Paralel modda, yinelemeler içindeki görevler paralel yürütmeyi destekler. 
Bunu sağdaki özellikler panelinde yapılandırabilirsiniz.', + MaxParallelismDesc: 'Maksimum paralellik, tek bir yinelemede aynı anda yürütülen görevlerin sayısını kontrol etmek için kullanılır.', }, note: { addNote: 'Not Ekle', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index 1828b6499f..663b5e4c13 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Ітерація', iteration_other: '{{count}} Ітерацій', currentIteration: 'Поточна ітерація', + ErrorMethod: { + operationTerminated: 'Припинено', + continueOnError: 'Продовжити після помилки', + removeAbnormalOutput: 'видалити-ненормальний-вивід', + }, + error_one: '{{count}} Помилка', + comma: ',', + MaxParallelismTitle: 'Максимальна паралельність', + parallelModeUpper: 'ПАРАЛЕЛЬНИЙ РЕЖИМ', + error_other: '{{count}} Помилки', + parallelMode: 'Паралельний режим', + parallelModeEnableTitle: 'Увімкнено паралельний режим', + errorResponseMethod: 'Метод реагування на помилку', + parallelPanelDesc: 'У паралельному режимі завдання в ітерації підтримують паралельне виконання.', + parallelModeEnableDesc: 'У паралельному режимі завдання всередині ітерацій підтримують паралельне виконання. 
Ви можете налаштувати це на панелі властивостей праворуч.', + MaxParallelismDesc: 'Максимальний паралелізм використовується для контролю числа завдань, що виконуються одночасно за одну ітерацію.', + answerNodeWarningDesc: 'Попередження в паралельному режимі: вузли відповідей, призначення змінних розмови та постійні операції читання/запису в межах ітерацій можуть спричинити винятки.', }, note: { editor: { diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 2866af8a2a..1176fdd2b5 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}} Lặp', iteration_other: '{{count}} Lặp', currentIteration: 'Lặp hiện tại', + ErrorMethod: { + operationTerminated: 'Chấm dứt', + removeAbnormalOutput: 'loại bỏ-bất thường-đầu ra', + continueOnError: 'Tiếp tục lỗi', + }, + comma: ',', + error_other: '{{đếm}} Lỗi', + error_one: '{{đếm}} Lỗi', + MaxParallelismTitle: 'Song song tối đa', + parallelPanelDesc: 'Ở chế độ song song, các tác vụ trong quá trình lặp hỗ trợ thực thi song song.', + parallelMode: 'Chế độ song song', + parallelModeEnableTitle: 'Đã bật Chế độ song song', + errorResponseMethod: 'Phương pháp phản hồi lỗi', + MaxParallelismDesc: 'Tính song song tối đa được sử dụng để kiểm soát số lượng tác vụ được thực hiện đồng thời trong một lần lặp.', + answerNodeWarningDesc: 'Cảnh báo chế độ song song: Các nút trả lời, bài tập biến hội thoại và các thao tác đọc/ghi liên tục trong các lần lặp có thể gây ra ngoại lệ.', + parallelModeEnableDesc: 'Trong chế độ song song, các tác vụ trong các lần lặp hỗ trợ thực thi song song. 
Bạn có thể định cấu hình điều này trong bảng thuộc tính ở bên phải.', + parallelModeUpper: 'CHẾ ĐỘ SONG SONG', }, note: { editor: { diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index d65b3999d2..f3fbfdedc2 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -557,6 +557,23 @@ const translation = { iteration_one: '{{count}}個迭代', iteration_other: '{{count}}個迭代', currentIteration: '當前迭代', + ErrorMethod: { + operationTerminated: '終止', + removeAbnormalOutput: 'remove-abnormal-output', + continueOnError: '出錯時繼續', + }, + comma: ',', + parallelMode: '並行模式', + parallelModeEnableTitle: 'Parallel Mode 已啟用', + MaxParallelismTitle: '最大並行度', + parallelModeUpper: '並行模式', + parallelPanelDesc: '在並行模式下,反覆運算中的任務支援並行執行。', + error_one: '{{count}}錯誤', + errorResponseMethod: '錯誤回應方法', + parallelModeEnableDesc: '在並行模式下,反覆運算中的任務支援並行執行。您可以在右側的 properties 面板中進行配置。', + answerNodeWarningDesc: '並行模式警告:反覆運算中的應答節點、對話變數賦值和持久讀/寫操作可能會導致異常。', + error_other: '{{count}}錯誤', + MaxParallelismDesc: '最大並行度用於控制在單個反覆運算中同時執行的任務數。', }, note: { editor: { From 302f4407f6ec88c1cd68cc4ab8d809a8f301c472 Mon Sep 17 00:00:00 2001 From: NFish Date: Tue, 5 Nov 2024 12:38:31 +0800 Subject: [PATCH 20/29] refactor the logic of refreshing access_token (#10068) --- web/app/account/avatar.tsx | 5 +- .../header/account-dropdown/index.tsx | 5 +- web/app/components/swr-initor.tsx | 39 ++---- web/app/signin/normalForm.tsx | 5 +- web/hooks/use-refresh-token.ts | 99 -------------- web/service/base.ts | 128 +++++++++++------- web/service/refresh-token.ts | 75 ++++++++++ 7 files changed, 171 insertions(+), 185 deletions(-) delete mode 100644 web/hooks/use-refresh-token.ts create mode 100644 web/service/refresh-token.ts diff --git a/web/app/account/avatar.tsx b/web/app/account/avatar.tsx index 2b9aeba5da..544e43ab27 100644 --- a/web/app/account/avatar.tsx +++ b/web/app/account/avatar.tsx @@ -23,8 +23,9 @@ export default function AppSelector() { params: {}, }) - if 
(localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') + localStorage.removeItem('setup_status') + localStorage.removeItem('console_token') + localStorage.removeItem('refresh_token') router.push('/signin') } diff --git a/web/app/components/header/account-dropdown/index.tsx b/web/app/components/header/account-dropdown/index.tsx index 712906ebae..14f079c0f2 100644 --- a/web/app/components/header/account-dropdown/index.tsx +++ b/web/app/components/header/account-dropdown/index.tsx @@ -47,8 +47,9 @@ export default function AppSelector({ isMobile }: IAppSelector) { params: {}, }) - if (localStorage?.getItem('console_token')) - localStorage.removeItem('console_token') + localStorage.removeItem('setup_status') + localStorage.removeItem('console_token') + localStorage.removeItem('refresh_token') router.push('/signin') } diff --git a/web/app/components/swr-initor.tsx b/web/app/components/swr-initor.tsx index ff9a7b832f..2a119df996 100644 --- a/web/app/components/swr-initor.tsx +++ b/web/app/components/swr-initor.tsx @@ -4,7 +4,6 @@ import { SWRConfig } from 'swr' import { useCallback, useEffect, useState } from 'react' import type { ReactNode } from 'react' import { usePathname, useRouter, useSearchParams } from 'next/navigation' -import useRefreshToken from '@/hooks/use-refresh-token' import { fetchSetupStatus } from '@/service/common' type SwrInitorProps = { @@ -15,12 +14,11 @@ const SwrInitor = ({ }: SwrInitorProps) => { const router = useRouter() const searchParams = useSearchParams() - const pathname = usePathname() - const { getNewAccessToken } = useRefreshToken() - const consoleToken = searchParams.get('access_token') - const refreshToken = searchParams.get('refresh_token') + const consoleToken = decodeURIComponent(searchParams.get('access_token') || '') + const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '') const consoleTokenFromLocalStorage = localStorage?.getItem('console_token') const 
refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token') + const pathname = usePathname() const [init, setInit] = useState(false) const isSetupFinished = useCallback(async () => { @@ -41,25 +39,6 @@ const SwrInitor = ({ } }, []) - const setRefreshToken = useCallback(async () => { - try { - if (!(consoleToken || refreshToken || consoleTokenFromLocalStorage || refreshTokenFromLocalStorage)) - return Promise.reject(new Error('No token found')) - - if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage) - await getNewAccessToken() - - if (consoleToken && refreshToken) { - localStorage.setItem('console_token', consoleToken) - localStorage.setItem('refresh_token', refreshToken) - await getNewAccessToken() - } - } - catch (error) { - return Promise.reject(error) - } - }, [consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage, getNewAccessToken]) - useEffect(() => { (async () => { try { @@ -68,9 +47,15 @@ const SwrInitor = ({ router.replace('/install') return } - await setRefreshToken() - if (searchParams.has('access_token') || searchParams.has('refresh_token')) + if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) { + router.replace('/signin') + return + } + if (searchParams.has('access_token') || searchParams.has('refresh_token')) { + consoleToken && localStorage.setItem('console_token', consoleToken) + refreshToken && localStorage.setItem('refresh_token', refreshToken) router.replace(pathname) + } setInit(true) } @@ -78,7 +63,7 @@ const SwrInitor = ({ router.replace('/signin') } })() - }, [isSetupFinished, setRefreshToken, router, pathname, searchParams]) + }, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage]) return init ? 
( diff --git a/web/app/signin/normalForm.tsx b/web/app/signin/normalForm.tsx index c0f2d89b37..f4f46c68ba 100644 --- a/web/app/signin/normalForm.tsx +++ b/web/app/signin/normalForm.tsx @@ -12,11 +12,9 @@ import cn from '@/utils/classnames' import { getSystemFeatures, invitationCheck } from '@/service/common' import { defaultSystemFeatures } from '@/types/feature' import Toast from '@/app/components/base/toast' -import useRefreshToken from '@/hooks/use-refresh-token' import { IS_CE_EDITION } from '@/config' const NormalForm = () => { - const { getNewAccessToken } = useRefreshToken() const { t } = useTranslation() const router = useRouter() const searchParams = useSearchParams() @@ -38,7 +36,6 @@ const NormalForm = () => { if (consoleToken && refreshToken) { localStorage.setItem('console_token', consoleToken) localStorage.setItem('refresh_token', refreshToken) - getNewAccessToken() router.replace('/apps') return } @@ -71,7 +68,7 @@ const NormalForm = () => { setSystemFeatures(defaultSystemFeatures) } finally { setIsLoading(false) } - }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, getNewAccessToken]) + }, [consoleToken, refreshToken, message, router, invite_token, isInviteLink]) useEffect(() => { init() }, [init]) diff --git a/web/hooks/use-refresh-token.ts b/web/hooks/use-refresh-token.ts deleted file mode 100644 index 53dc4faf00..0000000000 --- a/web/hooks/use-refresh-token.ts +++ /dev/null @@ -1,99 +0,0 @@ -'use client' -import { useCallback, useEffect, useRef } from 'react' -import { jwtDecode } from 'jwt-decode' -import dayjs from 'dayjs' -import utc from 'dayjs/plugin/utc' -import { useRouter } from 'next/navigation' -import type { CommonResponse } from '@/models/common' -import { fetchNewToken } from '@/service/common' -import { fetchWithRetry } from '@/utils' - -dayjs.extend(utc) - -const useRefreshToken = () => { - const router = useRouter() - const timer = useRef() - const advanceTime = useRef(5 * 60 * 1000) - - const 
getExpireTime = useCallback((token: string) => { - if (!token) - return 0 - const decoded = jwtDecode(token) - return (decoded.exp || 0) * 1000 - }, []) - - const getCurrentTimeStamp = useCallback(() => { - return dayjs.utc().valueOf() - }, []) - - const handleError = useCallback(() => { - localStorage?.removeItem('is_refreshing') - localStorage?.removeItem('console_token') - localStorage?.removeItem('refresh_token') - router.replace('/signin') - }, []) - - const getNewAccessToken = useCallback(async () => { - const currentAccessToken = localStorage?.getItem('console_token') - const currentRefreshToken = localStorage?.getItem('refresh_token') - if (!currentAccessToken || !currentRefreshToken) { - handleError() - return new Error('No access token or refresh token found') - } - if (localStorage?.getItem('is_refreshing') === '1') { - clearTimeout(timer.current) - timer.current = setTimeout(() => { - getNewAccessToken() - }, 1000) - return null - } - const currentTokenExpireTime = getExpireTime(currentAccessToken) - if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) { - localStorage?.setItem('is_refreshing', '1') - const [e, res] = await fetchWithRetry(fetchNewToken({ - body: { refresh_token: currentRefreshToken }, - }) as Promise) - if (e) { - handleError() - return e - } - const { access_token, refresh_token } = res.data - localStorage?.setItem('is_refreshing', '0') - localStorage?.setItem('console_token', access_token) - localStorage?.setItem('refresh_token', refresh_token) - const newTokenExpireTime = getExpireTime(access_token) - clearTimeout(timer.current) - timer.current = setTimeout(() => { - getNewAccessToken() - }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp()) - } - else { - const newTokenExpireTime = getExpireTime(currentAccessToken) - clearTimeout(timer.current) - timer.current = setTimeout(() => { - getNewAccessToken() - }, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp()) - } - return null - }, 
[getExpireTime, getCurrentTimeStamp, handleError]) - - const handleVisibilityChange = useCallback(() => { - if (document.visibilityState === 'visible') - getNewAccessToken() - }, []) - - useEffect(() => { - window.addEventListener('visibilitychange', handleVisibilityChange) - return () => { - window.removeEventListener('visibilitychange', handleVisibilityChange) - clearTimeout(timer.current) - localStorage?.removeItem('is_refreshing') - } - }, []) - - return { - getNewAccessToken, - } -} - -export default useRefreshToken diff --git a/web/service/base.ts b/web/service/base.ts index fbdd5c1fd3..fcf8d8bd7d 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -1,3 +1,4 @@ +import { refreshAccessTokenOrRelogin } from './refresh-token' import { API_PREFIX, IS_CE_EDITION, PUBLIC_API_PREFIX } from '@/config' import Toast from '@/app/components/base/toast' import type { AnnotationReply, MessageEnd, MessageReplace, ThoughtItem } from '@/app/components/base/chat/chat/type' @@ -356,39 +357,8 @@ const baseFetch = ( if (!/^(2|3)\d{2}$/.test(String(res.status))) { const bodyJson = res.json() switch (res.status) { - case 401: { - if (isPublicAPI) { - return bodyJson.then((data: ResponseError) => { - if (data.code === 'web_sso_auth_required') - requiredWebSSOLogin() - - if (data.code === 'unauthorized') { - removeAccessToken() - globalThis.location.reload() - } - - return Promise.reject(data) - }) - } - const loginUrl = `${globalThis.location.origin}/signin` - bodyJson.then((data: ResponseError) => { - if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent) - Toast.notify({ type: 'error', message: data.message, duration: 4000 }) - else if (data.code === 'not_init_validated' && IS_CE_EDITION) - globalThis.location.href = `${globalThis.location.origin}/init` - else if (data.code === 'not_setup' && IS_CE_EDITION) - globalThis.location.href = `${globalThis.location.origin}/install` - else if (location.pathname !== '/signin' || !IS_CE_EDITION) - 
globalThis.location.href = loginUrl - else if (!silent) - Toast.notify({ type: 'error', message: data.message }) - }).catch(() => { - // Handle any other errors - globalThis.location.href = loginUrl - }) - - break - } + case 401: + return Promise.reject(resClone) case 403: bodyJson.then((data: ResponseError) => { if (!silent) @@ -484,7 +454,9 @@ export const upload = (options: any, isPublicAPI?: boolean, url?: string, search export const ssePost = ( url: string, fetchOptions: FetchOptionType, - { + otherOptions: IOtherOptions, +) => { + const { isPublicAPI = false, onData, onCompleted, @@ -507,8 +479,7 @@ export const ssePost = ( onTextReplace, onError, getAbortController, - }: IOtherOptions, -) => { + } = otherOptions const abortController = new AbortController() const options = Object.assign({}, baseOptions, { @@ -532,21 +503,29 @@ export const ssePost = ( globalThis.fetch(urlWithPrefix, options as RequestInit) .then((res) => { if (!/^(2|3)\d{2}$/.test(String(res.status))) { - res.json().then((data: any) => { - if (isPublicAPI) { - if (data.code === 'web_sso_auth_required') - requiredWebSSOLogin() + if (res.status === 401) { + refreshAccessTokenOrRelogin(TIME_OUT).then(() => { + ssePost(url, fetchOptions, otherOptions) + }).catch(() => { + res.json().then((data: any) => { + if (isPublicAPI) { + if (data.code === 'web_sso_auth_required') + requiredWebSSOLogin() - if (data.code === 'unauthorized') { - removeAccessToken() - globalThis.location.reload() - } - if (res.status === 401) - return - } - Toast.notify({ type: 'error', message: data.message || 'Server Error' }) - }) - onError?.('Server Error') + if (data.code === 'unauthorized') { + removeAccessToken() + globalThis.location.reload() + } + } + }) + }) + } + else { + res.json().then((data) => { + Toast.notify({ type: 'error', message: data.message || 'Server Error' }) + }) + onError?.('Server Error') + } return } return handleStream(res, (str: string, isFirstMessage: boolean, moreInfo: IOnDataMoreInfo) => { @@ 
-568,7 +547,54 @@ export const ssePost = ( // base request export const request = (url: string, options = {}, otherOptions?: IOtherOptions) => { - return baseFetch(url, options, otherOptions || {}) + return new Promise((resolve, reject) => { + const otherOptionsForBaseFetch = otherOptions || {} + baseFetch(url, options, otherOptionsForBaseFetch).then(resolve).catch((errResp) => { + if (errResp?.status === 401) { + return refreshAccessTokenOrRelogin(TIME_OUT).then(() => { + baseFetch(url, options, otherOptionsForBaseFetch).then(resolve).catch(reject) + }).catch(() => { + const { + isPublicAPI = false, + silent, + } = otherOptionsForBaseFetch + const bodyJson = errResp.json() + if (isPublicAPI) { + return bodyJson.then((data: ResponseError) => { + if (data.code === 'web_sso_auth_required') + requiredWebSSOLogin() + + if (data.code === 'unauthorized') { + removeAccessToken() + globalThis.location.reload() + } + + return Promise.reject(data) + }) + } + const loginUrl = `${globalThis.location.origin}/signin` + bodyJson.then((data: ResponseError) => { + if (data.code === 'init_validate_failed' && IS_CE_EDITION && !silent) + Toast.notify({ type: 'error', message: data.message, duration: 4000 }) + else if (data.code === 'not_init_validated' && IS_CE_EDITION) + globalThis.location.href = `${globalThis.location.origin}/init` + else if (data.code === 'not_setup' && IS_CE_EDITION) + globalThis.location.href = `${globalThis.location.origin}/install` + else if (location.pathname !== '/signin' || !IS_CE_EDITION) + globalThis.location.href = loginUrl + else if (!silent) + Toast.notify({ type: 'error', message: data.message }) + }).catch(() => { + // Handle any other errors + globalThis.location.href = loginUrl + }) + }) + } + else { + reject(errResp) + } + }) + }) } // request methods diff --git a/web/service/refresh-token.ts b/web/service/refresh-token.ts new file mode 100644 index 0000000000..8bd2215041 --- /dev/null +++ b/web/service/refresh-token.ts @@ -0,0 +1,75 @@ +import { 
apiPrefix } from '@/config' +import { fetchWithRetry } from '@/utils' + +let isRefreshing = false +function waitUntilTokenRefreshed() { + return new Promise((resolve, reject) => { + function _check() { + const isRefreshingSign = localStorage.getItem('is_refreshing') + if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) { + setTimeout(() => { + _check() + }, 1000) + } + else { + resolve() + } + } + _check() + }) +} + +// only one request can send +async function getNewAccessToken(): Promise { + try { + const isRefreshingSign = localStorage.getItem('is_refreshing') + if ((isRefreshingSign && isRefreshingSign === '1') || isRefreshing) { + await waitUntilTokenRefreshed() + } + else { + globalThis.localStorage.setItem('is_refreshing', '1') + isRefreshing = true + const refresh_token = globalThis.localStorage.getItem('refresh_token') + + // Do not use baseFetch to refresh tokens. + // If a 401 response occurs and baseFetch itself attempts to refresh the token, + // it can lead to an infinite loop if the refresh attempt also returns 401. + // To avoid this, handle token refresh separately in a dedicated function + // that does not call baseFetch and uses a single retry mechanism. 
+ const [error, ret] = await fetchWithRetry(globalThis.fetch(`${apiPrefix}/refresh-token`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json;utf-8', + }, + body: JSON.stringify({ refresh_token }), + })) + if (error) { + return Promise.reject(error) + } + else { + if (ret.status === 401) + return Promise.reject(ret) + + const { data } = await ret.json() + globalThis.localStorage.setItem('console_token', data.access_token) + globalThis.localStorage.setItem('refresh_token', data.refresh_token) + } + } + } + catch (error) { + console.error(error) + return Promise.reject(error) + } + finally { + isRefreshing = false + globalThis.localStorage.removeItem('is_refreshing') + } +} + +export async function refreshAccessTokenOrRelogin(timeout: number) { + return Promise.race([new Promise((resolve, reject) => setTimeout(() => { + isRefreshing = false + globalThis.localStorage.removeItem('is_refreshing') + reject(new Error('request timeout')) + }, timeout)), getNewAccessToken()]) +} From 08c731fd847d416f43f713f44c0dcbe4d2288ad7 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 5 Nov 2024 14:23:18 +0800 Subject: [PATCH 21/29] fix(node): correct file property name in function switch (#10284) --- api/core/workflow/nodes/list_operator/node.py | 2 +- .../core/workflow/nodes/test_list_operator.py | 49 +++++++++++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 0406b97eb8..49e7ca85fd 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -157,7 +157,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: return lambda x: x.type case "extension": return lambda x: x.extension or "" - case "mimetype": + case "mime_type": return lambda x: x.mime_type or "" case "transfer_method": return lambda x: x.transfer_method diff --git 
a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 53e3c93fcc..0f5c8bf51b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock import pytest -from core.file import File -from core.file.models import FileTransferMethod, FileType +from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy -from core.workflow.nodes.list_operator.node import ListOperatorNode +from core.workflow.nodes.list_operator.exc import InvalidKeyError +from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func from models.workflow import WorkflowNodeExecutionStatus @@ -109,3 +109,46 @@ def test_filter_files_by_type(list_operator_node): assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id + + +def test_get_file_extract_string_func(): + # Create a File object + file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename="test_file.txt", + extension=".txt", + mime_type="text/plain", + remote_url="https://example.com/test_file.txt", + related_id="test_related_id", + ) + + # Test each case + assert _get_file_extract_string_func(key="name")(file) == "test_file.txt" + assert _get_file_extract_string_func(key="type")(file) == "document" + assert _get_file_extract_string_func(key="extension")(file) == ".txt" + assert _get_file_extract_string_func(key="mime_type")(file) == "text/plain" + assert _get_file_extract_string_func(key="transfer_method")(file) == "local_file" + assert 
_get_file_extract_string_func(key="url")(file) == "https://example.com/test_file.txt" + + # Test with empty values + empty_file = File( + tenant_id="test_tenant", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + filename=None, + extension=None, + mime_type=None, + remote_url=None, + related_id="test_related_id", + ) + + assert _get_file_extract_string_func(key="name")(empty_file) == "" + assert _get_file_extract_string_func(key="extension")(empty_file) == "" + assert _get_file_extract_string_func(key="mime_type")(empty_file) == "" + assert _get_file_extract_string_func(key="url")(empty_file) == "" + + # Test invalid key + with pytest.raises(InvalidKeyError): + _get_file_extract_string_func(key="invalid_key") From 249b897872c65aea27e0505bb5d681ffc0b16e3c Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 5 Nov 2024 14:40:57 +0800 Subject: [PATCH 22/29] feat(model): add validation for custom disclaimer length (#10287) --- api/models/model.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/api/models/model.py b/api/models/model.py index bd124cce8e..d049cd373d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1298,7 +1298,7 @@ class Site(db.Model): privacy_policy = db.Column(db.String(255)) show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") + _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="") customize_domain = db.Column(db.String(255)) customize_token_strategy = db.Column(db.String(255), nullable=False) prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @@ -1309,6 +1309,16 @@ class Site(db.Model): updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) code = 
db.Column(db.String(255)) + @property + def custom_disclaimer(self): + return self._custom_disclaimer + + @custom_disclaimer.setter + def custom_disclaimer(self, value: str): + if len(value) > 512: + raise ValueError("Custom disclaimer cannot exceed 512 characters.") + self._custom_disclaimer = value + @staticmethod def generate_code(n): while True: From cb245b54354245388824e2f5541481cf633c27ef Mon Sep 17 00:00:00 2001 From: Matsuda Date: Tue, 5 Nov 2024 15:41:15 +0900 Subject: [PATCH 23/29] fix(model_runtime): fix wrong max_tokens for Claude 3.5 Haiku on Amazon Bedrock (#10286) --- .../bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml index 7c676136db..35fc8d0d11 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml +++ b/api/core/model_runtime/model_providers/bedrock/llm/anthropic.claude-3-5-haiku-v1.yaml @@ -16,9 +16,9 @@ parameter_rules: use_template: max_tokens required: true type: int - default: 4096 + default: 8192 min: 1 - max: 4096 + max: 8192 help: zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。 en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter. 
From 4847548779f2ca93683702d787aae21666263fc2 Mon Sep 17 00:00:00 2001 From: Matsuda Date: Tue, 5 Nov 2024 15:41:39 +0900 Subject: [PATCH 24/29] feat(model_runtime): add new model 'claude-3-5-haiku-20241022' (#10285) --- .../anthropic/llm/_position.yaml | 1 + .../llm/claude-3-5-haiku-20241022.yaml | 39 +++++++++++++++++++ 2 files changed, 40 insertions(+) create mode 100644 api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml diff --git a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml index aca9456313..b7b28a70d4 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/anthropic/llm/_position.yaml @@ -1,3 +1,4 @@ +- claude-3-5-haiku-20241022 - claude-3-5-sonnet-20241022 - claude-3-5-sonnet-20240620 - claude-3-haiku-20240307 diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml new file mode 100644 index 0000000000..cae4c67e4a --- /dev/null +++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-haiku-20241022.yaml @@ -0,0 +1,39 @@ +model: claude-3-5-haiku-20241022 +label: + en_US: claude-3-5-haiku-20241022 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 200000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_tokens + use_template: max_tokens + required: true + default: 8192 + min: 1 + max: 8192 + - name: response_format + use_template: response_format +pricing: + input: '1.00' + output: '5.00' + unit: '0.000001' + currency: USD From bf9349c4dc22d4cbfe76ec1db057cf5a53dd3aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9D=9E=E6=B3=95=E6=93=8D=E4=BD=9C?= Date: Tue, 5 Nov 2024 14:42:47 +0800 Subject: [PATCH 25/29] feat: add xAI model provider (#10272) --- .../model_providers/x/__init__.py | 0 .../model_providers/x/_assets/x-ai-logo.svg | 1 + .../model_providers/x/llm/__init__.py | 0 .../model_providers/x/llm/grok-beta.yaml | 63 ++++++ .../model_providers/x/llm/llm.py | 37 ++++ api/core/model_runtime/model_providers/x/x.py | 25 +++ .../model_runtime/model_providers/x/x.yaml | 38 ++++ api/tests/integration_tests/.env.example | 4 + .../model_runtime/x/__init__.py | 0 .../model_runtime/x/test_llm.py | 204 ++++++++++++++++++ 10 files changed, 372 insertions(+) create mode 100644 api/core/model_runtime/model_providers/x/__init__.py create mode 100644 api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg create mode 100644 api/core/model_runtime/model_providers/x/llm/__init__.py create mode 100644 api/core/model_runtime/model_providers/x/llm/grok-beta.yaml create mode 100644 api/core/model_runtime/model_providers/x/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/x/x.py create mode 100644 api/core/model_runtime/model_providers/x/x.yaml create mode 100644 api/tests/integration_tests/model_runtime/x/__init__.py create mode 100644 api/tests/integration_tests/model_runtime/x/test_llm.py diff --git a/api/core/model_runtime/model_providers/x/__init__.py b/api/core/model_runtime/model_providers/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg new file mode 100644 index 
0000000000..f8b745cb13 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/_assets/x-ai-logo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/x/llm/__init__.py b/api/core/model_runtime/model_providers/x/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml new file mode 100644 index 0000000000..7c305735b9 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -0,0 +1,63 @@ +model: grok-beta +label: + en_US: Grok beta +model_type: llm +features: + - multi-tool-call +model_properties: + mode: chat + context_size: 131072 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." 
+ zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py new file mode 100644 index 0000000000..3f5325a857 --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -0,0 +1,37 @@ +from collections.abc import Generator +from typing import Optional, Union + +from yarl import URL + +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke( + self, + model: str, + credentials: dict, + prompt_messages: list[PromptMessage], + model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, + stream: bool = True, + user: Optional[str] = None, + ) -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) 
+ + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/x/x.py b/api/core/model_runtime/model_providers/x/x.py new file mode 100644 index 0000000000..e3f2b8eeba --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.py @@ -0,0 +1,25 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class XAIProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + model_instance.validate_credentials(model="grok-beta", credentials=credentials) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f"{self.get_provider_schema().provider} credentials validate failed") + raise ex diff --git a/api/core/model_runtime/model_providers/x/x.yaml b/api/core/model_runtime/model_providers/x/x.yaml new file mode 100644 index 0000000000..90d1cbfe7e --- /dev/null +++ b/api/core/model_runtime/model_providers/x/x.yaml @@ -0,0 +1,38 @@ +provider: x +label: + en_US: xAI +description: + en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. 
We are guided by our mission to advance our collective understanding of the universe. +icon_small: + en_US: x-ai-logo.svg +icon_large: + en_US: x-ai-logo.svg +help: + title: + en_US: Get your token from xAI + zh_Hans: 从 xAI 获取 token + url: + en_US: https://x.ai/api +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key + - variable: endpoint_url + label: + en_US: API Base + type: text-input + required: false + default: https://api.x.ai/v1 + placeholder: + zh_Hans: 在此输入您的 API Base + en_US: Enter your API Base diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 99728a8271..6fd144c5c2 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -95,3 +95,7 @@ GPUSTACK_API_KEY= # Gitee AI Credentials GITEE_AI_API_KEY= + +# xAI Credentials +XAI_API_KEY= +XAI_API_BASE= diff --git a/api/tests/integration_tests/model_runtime/x/__init__.py b/api/tests/integration_tests/model_runtime/x/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/model_runtime/x/test_llm.py b/api/tests/integration_tests/model_runtime/x/test_llm.py new file mode 100644 index 0000000000..647a2f6480 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/x/test_llm.py @@ -0,0 +1,204 @@ +import os +from collections.abc import Generator + +import pytest + +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import 
CredentialsValidateFailedError +from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel + +"""FOR MOCK FIXTURES, DO NOT REMOVE""" +from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock + + +def test_predefined_models(): + model = XAILargeLanguageModel() + model_schemas = model.predefined_models() + + assert len(model_schemas) >= 1 + assert isinstance(model_schemas[0], AIModelEntity) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_validate_credentials_for_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + with pytest.raises(CredentialsValidateFailedError): + # model name to gpt-3.5-turbo because of mocking + model.validate_credentials( + model="gpt-3.5-turbo", + credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"}, + ) + + model.validate_credentials( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + ) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={ + "temperature": 0.0, + "top_p": 1.0, + "presence_penalty": 0.0, + "frequency_penalty": 0.0, + "max_tokens": 10, + }, + stop=["How"], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert len(result.message.content) > 0 + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_chat_model_with_tools(setup_openai_mock): + model = 
XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage( + content="what's the weather today in London?", + ), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + PromptMessageTool( + name="get_stock_price", + description="Get the current stock price", + parameters={ + "type": "object", + "properties": {"symbol": {"type": "string", "description": "The stock symbol"}}, + "required": ["symbol"], + }, + ), + ], + stream=False, + user="foo", + ) + + assert isinstance(result, LLMResult) + assert isinstance(result.message, AssistantPromptMessage) + + +@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True) +def test_invoke_stream_chat_model(setup_openai_mock): + model = XAILargeLanguageModel() + + result = model.invoke( + model="grok-beta", + credentials={ + "api_key": os.environ.get("XAI_API_KEY"), + "endpoint_url": os.environ.get("XAI_API_BASE"), + "mode": "chat", + }, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + model_parameters={"temperature": 0.0, "max_tokens": 100}, + stream=True, + user="foo", + ) + + assert isinstance(result, Generator) + + for chunk in result: + assert isinstance(chunk, LLMResultChunk) + assert isinstance(chunk.delta, LLMResultChunkDelta) + assert isinstance(chunk.delta.message, AssistantPromptMessage) + assert 
len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + if chunk.delta.finish_reason is not None: + assert chunk.delta.usage is not None + assert chunk.delta.usage.completion_tokens > 0 + + +def test_get_num_tokens(): + model = XAILargeLanguageModel() + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[UserPromptMessage(content="Hello World!")], + ) + + assert num_tokens == 10 + + num_tokens = model.get_num_tokens( + model="grok-beta", + credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")}, + prompt_messages=[ + SystemPromptMessage( + content="You are a helpful AI assistant.", + ), + UserPromptMessage(content="Hello World!"), + ], + tools=[ + PromptMessageTool( + name="get_weather", + description="Determine weather in my location", + parameters={ + "type": "object", + "properties": { + "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"}, + "unit": {"type": "string", "enum": ["c", "f"]}, + }, + "required": ["location"], + }, + ), + ], + ) + + assert num_tokens == 77 From 233bffdb7d4ce967295d03cd283b104b202c945e Mon Sep 17 00:00:00 2001 From: eux Date: Tue, 5 Nov 2024 14:42:59 +0800 Subject: [PATCH 26/29] fix: borken faq url in CONTRIBUTING.md (#10275) --- CONTRIBUTING.md | 2 +- CONTRIBUTING_VI.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8f57cd545e..da2928d189 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. 
In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install. -Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot. +Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot. ### 5. Visit dify in your browser diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md index 80e68a046e..a77239ff38 100644 --- a/CONTRIBUTING_VI.md +++ b/CONTRIBUTING_VI.md @@ -79,7 +79,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt. -Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. +Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục. ### 5. 
Truy cập Dify trong trình duyệt của bạn From 5f21d13572aa0a60ab4c7482626d0003e7def2c2 Mon Sep 17 00:00:00 2001 From: pinsily <13160724868@163.com> Date: Tue, 5 Nov 2024 14:47:15 +0800 Subject: [PATCH 27/29] fix: handle KeyError when accessing rules in CleanProcessor.clean (#10258) --- api/core/indexing_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index fb9fe8f210..e2a94073cf 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -598,7 +598,7 @@ class IndexingRunner: rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} - document_text = CleanProcessor.clean(text, rules) + document_text = CleanProcessor.clean(text, {"rules": rules}) return document_text From 68e0b0ac84e1cbde7233f9e4ab66236bffe20b4f Mon Sep 17 00:00:00 2001 From: Matsuda Date: Tue, 5 Nov 2024 17:09:53 +0900 Subject: [PATCH 28/29] fix typo: writeOpner to writeOpener (#10290) --- web/i18n/pl-PL/app-debug.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/web/i18n/pl-PL/app-debug.ts b/web/i18n/pl-PL/app-debug.ts index 7cf6c77cb4..cf7232e563 100644 --- a/web/i18n/pl-PL/app-debug.ts +++ b/web/i18n/pl-PL/app-debug.ts @@ -355,7 +355,7 @@ const translation = { openingStatement: { title: 'Wstęp do rozmowy', add: 'Dodaj', - writeOpner: 'Napisz wstęp', + writeOpener: 'Napisz wstęp', placeholder: 'Tutaj napisz swoją wiadomość wprowadzającą, możesz użyć zmiennych, spróbuj wpisać {{variable}}.', openingQuestion: 'Pytania otwierające', From ae254f0a10114060ee32ff521eb8bafec2acf792 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Tue, 5 Nov 2024 16:30:23 +0800 Subject: [PATCH 29/29] fix(http_request): improve parameter initialization and reorganize tests (#10297) --- .../workflow/nodes/http_request/executor.py | 6 +- .../test_http_request_executor.py | 198 ++++++++++++++++++ .../test_http_request_node.py | 169 
+-------------- 3 files changed, 203 insertions(+), 170 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py rename api/tests/unit_tests/core/workflow/nodes/{ => http_request}/test_http_request_node.py (52%) diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 6204fc2644..d90dfcc766 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -88,8 +88,10 @@ class Executor: self.url = self.variable_pool.convert_template(self.node_data.url).text def _init_params(self): - params = self.variable_pool.convert_template(self.node_data.params).text - self.params = _plain_text_to_dict(params) + params = _plain_text_to_dict(self.node_data.params) + for key in params: + params[key] = self.variable_pool.convert_template(params[key]).text + self.params = params def _init_headers(self): headers = self.variable_pool.convert_template(self.node_data.headers).text diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py new file mode 100644 index 0000000000..12c469a81a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -0,0 +1,198 @@ +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.nodes.http_request import ( + BodyData, + HttpRequestNodeAuthorization, + HttpRequestNodeBody, + HttpRequestNodeData, +) +from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout +from core.workflow.nodes.http_request.executor import Executor + + +def test_executor_with_json_body_and_number_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "number"], 42) + + # Prepare the node data + 
node_data = HttpRequestNodeData( + title="Test JSON Body with Number Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"number": {{#pre_node_id.number#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"number": 42} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '{"number": 42}' in raw_request + + +def test_executor_with_json_body_and_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value="{{#pre_node_id.object#}}", + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + 
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_executor_with_json_body_and_nested_object_variable(): + # Prepare the variable pool + variable_pool = VariablePool( + system_variables={}, + user_inputs={}, + ) + variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) + + # Prepare the node data + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="", + body=HttpRequestNodeBody( + type="json", + data=[ + BodyData( + key="", + type="text", + value='{"object": {{#pre_node_id.object#}}}', + ) + ], + ), + ) + + # Initialize the Executor + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + # Check the executor's data + assert executor.method == "post" + assert executor.url == "https://api.example.com/data" + assert executor.headers == {"Content-Type": "application/json"} + assert executor.params == {} + 
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} + assert executor.data is None + assert executor.files is None + assert executor.content is None + + # Check the raw request (to_log method) + raw_request = executor.to_log() + assert "POST /data HTTP/1.1" in raw_request + assert "Host: api.example.com" in raw_request + assert "Content-Type: application/json" in raw_request + assert '"object": {' in raw_request + assert '"name": "John Doe"' in raw_request + assert '"age": 30' in raw_request + assert '"email": "john@example.com"' in raw_request + + +def test_extract_selectors_from_template_with_newline(): + variable_pool = VariablePool() + variable_pool.add(("node_id", "custom_query"), "line1\nline2") + node_data = HttpRequestNodeData( + title="Test JSON Body with Nested Object Variable", + method="post", + url="https://api.example.com/data", + authorization=HttpRequestNodeAuthorization(type="no-auth"), + headers="Content-Type: application/json", + params="test: {{#node_id.custom_query#}}", + body=HttpRequestNodeBody( + type="none", + data=[], + ), + ) + + executor = Executor( + node_data=node_data, + timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), + variable_pool=variable_pool, + ) + + assert executor.params == {"test": "line1\nline2"} diff --git a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py similarity index 52% rename from api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py rename to api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 720037d05f..741a3a1894 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -1,5 +1,3 @@ -import json - import httpx from core.app.entities.app_invoke_entities import InvokeFrom @@ -16,8 +14,7 @@ from 
core.workflow.nodes.http_request import ( HttpRequestNodeBody, HttpRequestNodeData, ) -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout -from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict +from core.workflow.nodes.http_request.executor import _plain_text_to_dict from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -203,167 +200,3 @@ def test_http_request_node_form_with_file(monkeypatch): assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs is not None assert result.outputs["body"] == "" - - -def test_executor_with_json_body_and_number_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "number"], 42) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Number Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value='{"number": {{#pre_node_id.number#}}}', - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"number": 42} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in 
raw_request - assert "Content-Type: application/json" in raw_request - assert '{"number": 42}' in raw_request - - -def test_executor_with_json_body_and_object_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Object Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value="{{#pre_node_id.object#}}", - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '"name": "John Doe"' in raw_request - assert '"age": 30' in raw_request - assert '"email": "john@example.com"' in raw_request - - -def test_executor_with_json_body_and_nested_object_variable(): - # Prepare the variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - variable_pool.add(["pre_node_id", "object"], {"name": "John 
Doe", "age": 30, "email": "john@example.com"}) - - # Prepare the node data - node_data = HttpRequestNodeData( - title="Test JSON Body with Nested Object Variable", - method="post", - url="https://api.example.com/data", - authorization=HttpRequestNodeAuthorization(type="no-auth"), - headers="Content-Type: application/json", - params="", - body=HttpRequestNodeBody( - type="json", - data=[ - BodyData( - key="", - type="text", - value='{"object": {{#pre_node_id.object#}}}', - ) - ], - ), - ) - - # Initialize the Executor - executor = Executor( - node_data=node_data, - timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), - variable_pool=variable_pool, - ) - - # Check the executor's data - assert executor.method == "post" - assert executor.url == "https://api.example.com/data" - assert executor.headers == {"Content-Type": "application/json"} - assert executor.params == {} - assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}} - assert executor.data is None - assert executor.files is None - assert executor.content is None - - # Check the raw request (to_log method) - raw_request = executor.to_log() - assert "POST /data HTTP/1.1" in raw_request - assert "Host: api.example.com" in raw_request - assert "Content-Type: application/json" in raw_request - assert '"object": {' in raw_request - assert '"name": "John Doe"' in raw_request - assert '"age": 30' in raw_request - assert '"email": "john@example.com"' in raw_request