mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into feat/support-extractor-tools
This commit is contained in:
commit
cae7f7523b
|
|
@ -81,7 +81,7 @@ Dify requires the following dependencies to build, make sure they're installed o
|
|||
|
||||
Dify is composed of a backend and a frontend. Navigate to the backend directory by `cd api/`, then follow the [Backend README](api/README.md) to install it. In a separate terminal, navigate to the frontend directory by `cd web/`, then follow the [Frontend README](web/README.md) to install.
|
||||
|
||||
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/self-host-faq) for a list of common issues and steps to troubleshoot.
|
||||
Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) for a list of common issues and steps to troubleshoot.
|
||||
|
||||
### 5. Visit dify in your browser
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
|
|||
|
||||
Dify bao gồm một backend và một frontend. Đi đến thư mục backend bằng lệnh `cd api/`, sau đó làm theo hướng dẫn trong [README của Backend](api/README.md) để cài đặt. Trong một terminal khác, đi đến thư mục frontend bằng lệnh `cd web/`, sau đó làm theo hướng dẫn trong [README của Frontend](web/README.md) để cài đặt.
|
||||
|
||||
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/self-host-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
|
||||
Kiểm tra [FAQ về cài đặt](https://docs.dify.ai/learn-more/faq/install-faq) để xem danh sách các vấn đề thường gặp và các bước khắc phục.
|
||||
|
||||
### 5. Truy cập Dify trong trình duyệt của bạn
|
||||
|
||||
|
|
|
|||
|
|
@ -327,6 +327,9 @@ SSRF_DEFAULT_MAX_RETRIES=3
|
|||
BATCH_UPLOAD_LIMIT=10
|
||||
KEYWORD_DATA_SOURCE_TYPE=database
|
||||
|
||||
# Workflow file upload limit
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
|
||||
# CODE EXECUTION CONFIGURATION
|
||||
CODE_EXECUTION_ENDPOINT=http://127.0.0.1:8194
|
||||
CODE_EXECUTION_API_KEY=dify-sandbox
|
||||
|
|
|
|||
|
|
@ -55,12 +55,7 @@ RUN apt-get update \
|
|||
&& echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \
|
||||
&& apt-get update \
|
||||
# For Security
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 \
|
||||
&& if [ "$(dpkg --print-architecture)" = "amd64" ]; then \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1+b1; \
|
||||
else \
|
||||
apt-get install -y --no-install-recommends zlib1g=1:1.3.dfsg+really1.3.1-1; \
|
||||
fi \
|
||||
&& apt-get install -y --no-install-recommends expat=2.6.3-2 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-6 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \
|
||||
# install a chinese font to support the use of tools like matplotlib
|
||||
&& apt-get install -y fonts-noto-cjk \
|
||||
&& apt-get autoremove -y \
|
||||
|
|
|
|||
|
|
@ -216,6 +216,11 @@ class FileUploadConfig(BaseSettings):
|
|||
default=20,
|
||||
)
|
||||
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT: PositiveInt = Field(
|
||||
description="Maximum number of files allowed in a workflow upload operation",
|
||||
default=10,
|
||||
)
|
||||
|
||||
|
||||
class HttpConfig(BaseSettings):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
from flask_restful import fields
|
||||
|
||||
parameters__system_parameters = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(parameters__system_parameters),
|
||||
}
|
||||
|
|
@ -2,11 +2,15 @@ import mimetypes
|
|||
import os
|
||||
import re
|
||||
import urllib.parse
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
filename: str
|
||||
|
|
@ -56,3 +60,38 @@ def guess_file_info_from_response(response: httpx.Response):
|
|||
mimetype=mimetype,
|
||||
size=int(response.headers.get("Content-Length", -1)),
|
||||
)
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]):
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restful import fields, marshal_with
|
||||
from flask_restful import marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import AppUnavailableError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
|
|
@ -11,43 +12,14 @@ from services.app_service import AppService
|
|||
class AppParameterApi(InstalledAppResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, installed_app: InstalledApp):
|
||||
"""Retrieve app parameters."""
|
||||
app_model = installed_app.app
|
||||
|
||||
if app_model is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
workflow = app_model.workflow
|
||||
if workflow is None:
|
||||
|
|
@ -57,43 +29,16 @@ class AppParameterApi(InstalledAppResource):
|
|||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class ExploreAppMetaApi(InstalledAppResource):
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class FileApi(Resource):
|
|||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
}, 200
|
||||
|
||||
@setup_required
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restful import Resource, fields, marshal_with
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
|
|
@ -11,40 +12,8 @@ from services.app_service import AppService
|
|||
class AppParameterApi(Resource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
}
|
||||
|
||||
@validate_app_token
|
||||
@marshal_with(parameters_fields)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
|
|
@ -56,43 +25,16 @@ class AppParameterApi(Resource):
|
|||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class AppMetaApi(Resource):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from flask_restful import fields, marshal_with
|
||||
from flask_restful import marshal_with
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.common import fields
|
||||
from controllers.common import helpers as controller_helpers
|
||||
from controllers.web import api
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
|
|
@ -11,39 +12,7 @@ from services.app_service import AppService
|
|||
class AppParameterApi(WebApiResource):
|
||||
"""Resource for app variables."""
|
||||
|
||||
variable_fields = {
|
||||
"key": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"type": fields.String,
|
||||
"default": fields.String,
|
||||
"max_length": fields.Integer,
|
||||
"options": fields.List(fields.String),
|
||||
}
|
||||
|
||||
system_parameters_fields = {
|
||||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"file_size_limit": fields.Integer,
|
||||
}
|
||||
|
||||
parameters_fields = {
|
||||
"opening_statement": fields.String,
|
||||
"suggested_questions": fields.Raw,
|
||||
"suggested_questions_after_answer": fields.Raw,
|
||||
"speech_to_text": fields.Raw,
|
||||
"text_to_speech": fields.Raw,
|
||||
"retriever_resource": fields.Raw,
|
||||
"annotation_reply": fields.Raw,
|
||||
"more_like_this": fields.Raw,
|
||||
"user_input_form": fields.Raw,
|
||||
"sensitive_word_avoidance": fields.Raw,
|
||||
"file_upload": fields.Raw,
|
||||
"system_parameters": fields.Nested(system_parameters_fields),
|
||||
}
|
||||
|
||||
@marshal_with(parameters_fields)
|
||||
@marshal_with(fields.parameters_fields)
|
||||
def get(self, app_model: App, end_user):
|
||||
"""Retrieve app parameters."""
|
||||
if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
|
||||
|
|
@ -55,43 +24,16 @@ class AppParameterApi(WebApiResource):
|
|||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app_model.app_model_config
|
||||
if app_model_config is None:
|
||||
raise AppUnavailableError()
|
||||
|
||||
features_dict = app_model_config.to_dict()
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get(
|
||||
"suggested_questions_after_answer", {"enabled": False}
|
||||
),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": 3,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
},
|
||||
}
|
||||
return controller_helpers.get_parameters_from_feature_dict(
|
||||
features_dict=features_dict, user_input_form=user_input_form
|
||||
)
|
||||
|
||||
|
||||
class AppMeta(WebApiResource):
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.file.models import FileExtraConfig
|
||||
from models import FileUploadConfig
|
||||
from core.file import FileExtraConfig
|
||||
|
||||
|
||||
class FileUploadConfigManager:
|
||||
|
|
@ -43,6 +42,6 @@ class FileUploadConfigManager:
|
|||
if not config.get("file_upload"):
|
||||
config["file_upload"] = {}
|
||||
else:
|
||||
FileUploadConfig.model_validate(config["file_upload"])
|
||||
FileExtraConfig.model_validate(config["file_upload"])
|
||||
|
||||
return config, ["file_upload"]
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationStartEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
|
@ -314,7 +315,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
|
|||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
|
|
|
|||
|
|
@ -22,7 +22,10 @@ class BaseAppGenerator:
|
|||
user_inputs = user_inputs or {}
|
||||
# Filter input variables from form configuration, handle required fields, default values, and option values
|
||||
variables = app_config.variables
|
||||
user_inputs = {var.variable: self._validate_input(inputs=user_inputs, var=var) for var in variables}
|
||||
user_inputs = {
|
||||
var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
|
||||
for var in variables
|
||||
}
|
||||
user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
|
||||
# Convert files in inputs to File
|
||||
entity_dictionary = {item.variable: item for item in app_config.variables}
|
||||
|
|
@ -74,57 +77,66 @@ class BaseAppGenerator:
|
|||
|
||||
return user_inputs
|
||||
|
||||
def _validate_input(self, *, inputs: Mapping[str, Any], var: "VariableEntity"):
|
||||
user_input_value = inputs.get(var.variable)
|
||||
def _validate_inputs(
|
||||
self,
|
||||
*,
|
||||
variable_entity: "VariableEntity",
|
||||
value: Any,
|
||||
):
|
||||
if value is None:
|
||||
if variable_entity.required:
|
||||
raise ValueError(f"{variable_entity.variable} is required in input form")
|
||||
return value
|
||||
|
||||
if not user_input_value:
|
||||
if var.required:
|
||||
raise ValueError(f"{var.variable} is required in input form")
|
||||
else:
|
||||
return None
|
||||
|
||||
if var.type in {
|
||||
if variable_entity.type in {
|
||||
VariableEntityType.TEXT_INPUT,
|
||||
VariableEntityType.SELECT,
|
||||
VariableEntityType.PARAGRAPH,
|
||||
} and not isinstance(user_input_value, str):
|
||||
raise ValueError(f"(type '{var.type}') {var.variable} in input form must be a string")
|
||||
} and not isinstance(value, str):
|
||||
raise ValueError(
|
||||
f"(type '{variable_entity.type}') {variable_entity.variable} in input form must be a string"
|
||||
)
|
||||
|
||||
if var.type == VariableEntityType.NUMBER and isinstance(user_input_value, str):
|
||||
if variable_entity.type == VariableEntityType.NUMBER and isinstance(value, str):
|
||||
# may raise ValueError if user_input_value is not a valid number
|
||||
try:
|
||||
if "." in user_input_value:
|
||||
return float(user_input_value)
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(user_input_value)
|
||||
return int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{var.variable} in input form must be a valid number")
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a valid number")
|
||||
|
||||
match var.type:
|
||||
match variable_entity.type:
|
||||
case VariableEntityType.SELECT:
|
||||
if user_input_value not in var.options:
|
||||
raise ValueError(f"{var.variable} in input form must be one of the following: {var.options}")
|
||||
if value not in variable_entity.options:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be one of the following: "
|
||||
f"{variable_entity.options}"
|
||||
)
|
||||
case VariableEntityType.TEXT_INPUT | VariableEntityType.PARAGRAPH:
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} characters")
|
||||
if variable_entity.max_length and len(value) > variable_entity.max_length:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} "
|
||||
"characters"
|
||||
)
|
||||
case VariableEntityType.FILE:
|
||||
if not isinstance(user_input_value, dict) and not isinstance(user_input_value, File):
|
||||
raise ValueError(f"{var.variable} in input form must be a file")
|
||||
if not isinstance(value, dict) and not isinstance(value, File):
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a file")
|
||||
case VariableEntityType.FILE_LIST:
|
||||
# if number of files exceeds the limit, raise ValueError
|
||||
if not (
|
||||
isinstance(user_input_value, list)
|
||||
and (
|
||||
all(isinstance(item, dict) for item in user_input_value)
|
||||
or all(isinstance(item, File) for item in user_input_value)
|
||||
)
|
||||
isinstance(value, list)
|
||||
and (all(isinstance(item, dict) for item in value) or all(isinstance(item, File) for item in value))
|
||||
):
|
||||
raise ValueError(f"{var.variable} in input form must be a list of files")
|
||||
raise ValueError(f"{variable_entity.variable} in input form must be a list of files")
|
||||
|
||||
if var.max_length and len(user_input_value) > var.max_length:
|
||||
raise ValueError(f"{var.variable} in input form must be less than {var.max_length} files")
|
||||
if variable_entity.max_length and len(value) > variable_entity.max_length:
|
||||
raise ValueError(
|
||||
f"{variable_entity.variable} in input form must be less than {variable_entity.max_length} files"
|
||||
)
|
||||
|
||||
return user_input_value
|
||||
return value
|
||||
|
||||
def _sanitize_value(self, value: Any) -> Any:
|
||||
if isinstance(value, str):
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
|
@ -275,7 +276,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
|
|||
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueNodeFailedEvent):
|
||||
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent):
|
||||
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
|
||||
|
||||
response = self._workflow_node_finish_to_stream_response(
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
|
@ -30,6 +31,7 @@ from core.workflow.graph_engine.entities.event import (
|
|||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
|
|
@ -193,6 +195,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
|
|
@ -246,9 +249,40 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeInIterationFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeInIterationFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result
|
||||
else {},
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
|
|
@ -326,6 +360,7 @@ class WorkflowBasedAppRunner(AppRunner):
|
|||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
|
|
|
|||
|
|
@ -107,7 +107,8 @@ class QueueIterationNextEvent(AppQueueEvent):
|
|||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
node_run_index: int
|
||||
output: Optional[Any] = None # output for the current iteration
|
||||
|
||||
|
|
@ -273,6 +274,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
|||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
|
|
@ -306,6 +309,37 @@ class QueueNodeSucceededEvent(AppQueueEvent):
|
|||
error: Optional[str] = None
|
||||
|
||||
|
||||
class QueueNodeInIterationFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeInIterationFailedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.NODE_FAILED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_data: BaseNodeData
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
start_at: datetime
|
||||
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
||||
class QueueNodeFailedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueNodeFailedEvent entity
|
||||
|
|
@ -332,6 +366,7 @@ class QueueNodeFailedEvent(AppQueueEvent):
|
|||
inputs: Optional[dict[str, Any]] = None
|
||||
process_data: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
execution_metadata: Optional[dict[NodeRunMetadataKey, Any]] = None
|
||||
|
||||
error: str
|
||||
|
||||
|
|
|
|||
|
|
@ -244,6 +244,7 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
parent_parallel_id: Optional[str] = None
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
iteration_id: Optional[str] = None
|
||||
parallel_run_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.NODE_STARTED
|
||||
workflow_run_id: str
|
||||
|
|
@ -432,6 +433,7 @@ class IterationNodeNextStreamResponse(StreamResponse):
|
|||
extras: dict = {}
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
|
||||
event: StreamEvent = StreamEvent.ITERATION_NEXT
|
||||
workflow_run_id: str
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from core.app.entities.queue_entities import (
|
|||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
|
|
@ -35,6 +36,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
|||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
|
|
@ -251,6 +253,12 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
|
@ -305,7 +313,9 @@ class WorkflowCycleManage:
|
|||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(self, event: QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
def _handle_workflow_node_execution_failed(
|
||||
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
|
|
@ -318,16 +328,19 @@ class WorkflowCycleManage:
|
|||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
|
||||
{
|
||||
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
|
||||
WorkflowNodeExecution.error: event.error,
|
||||
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None,
|
||||
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
|
||||
WorkflowNodeExecution.finished_at: finished_at,
|
||||
WorkflowNodeExecution.elapsed_time: elapsed_time,
|
||||
WorkflowNodeExecution.execution_metadata: execution_metadata,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -342,6 +355,7 @@ class WorkflowCycleManage:
|
|||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
|
||||
|
||||
|
|
@ -448,6 +462,7 @@ class WorkflowCycleManage:
|
|||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -464,7 +479,7 @@ class WorkflowCycleManage:
|
|||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent,
|
||||
event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeInIterationFailedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
|
|
@ -608,6 +623,7 @@ class WorkflowCycleManage:
|
|||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -633,7 +649,9 @@ class WorkflowCycleManage:
|
|||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
|
|
|
|||
|
|
@ -598,7 +598,7 @@ class IndexingRunner:
|
|||
rules = DatasetProcessRule.AUTOMATIC_RULES
|
||||
else:
|
||||
rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
|
||||
document_text = CleanProcessor.clean(text, rules)
|
||||
document_text = CleanProcessor.clean(text, {"rules": rules})
|
||||
|
||||
return document_text
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
- claude-3-5-haiku-20241022
|
||||
- claude-3-5-sonnet-20241022
|
||||
- claude-3-5-sonnet-20240620
|
||||
- claude-3-haiku-20240307
|
||||
|
|
|
|||
|
|
@ -0,0 +1,39 @@
|
|||
model: claude-3-5-haiku-20241022
|
||||
label:
|
||||
en_US: claude-3-5-haiku-20241022
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '1.00'
|
||||
output: '5.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
model: anthropic.claude-3-5-haiku-20241022-v1:0
|
||||
label:
|
||||
en_US: Claude 3.5 Haiku
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
model: us.anthropic.claude-3-5-haiku-20241022-v1:0
|
||||
label:
|
||||
en_US: Claude 3.5 Haiku(US.Cross Region Inference)
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 200000
|
||||
# docs: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||
parameter_rules:
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
type: int
|
||||
default: 4096
|
||||
min: 1
|
||||
max: 4096
|
||||
help:
|
||||
zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
|
||||
en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
|
||||
# docs: https://docs.anthropic.com/claude/docs/system-prompts
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
required: false
|
||||
type: float
|
||||
default: 1
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
help:
|
||||
zh_Hans: 生成内容的随机性。
|
||||
en_US: The amount of randomness injected into the response.
|
||||
- name: top_p
|
||||
required: false
|
||||
type: float
|
||||
default: 0.999
|
||||
min: 0.000
|
||||
max: 1.000
|
||||
help:
|
||||
zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。
|
||||
en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both.
|
||||
- name: top_k
|
||||
required: false
|
||||
type: int
|
||||
default: 0
|
||||
min: 0
|
||||
# tip docs from aws has error, max value is 500
|
||||
max: 500
|
||||
help:
|
||||
zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。
|
||||
en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses.
|
||||
- name: response_format
|
||||
use_template: response_format
|
||||
pricing:
|
||||
input: '0.001'
|
||||
output: '0.005'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
import requests
|
||||
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
|
|
@ -16,8 +17,18 @@ class GiteeAIProvider(ModelProvider):
|
|||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="Qwen2-7B-Instruct", credentials=credentials)
|
||||
api_key = credentials.get("api_key")
|
||||
if not api_key:
|
||||
raise CredentialsValidateFailedError("Credentials validation failed: api_key not given")
|
||||
|
||||
# send a get request to validate the credentials
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
response = requests.get("https://ai.gitee.com/api/base/account/me", headers=headers, timeout=(10, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
raise CredentialsValidateFailedError(
|
||||
f"Credentials validation failed with status code {response.status_code}"
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
|
|
|
|||
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24" aria-hidden="true" class="" focusable="false" style="fill:currentColor;height:28px;width:28px"><path d="m3.005 8.858 8.783 12.544h3.904L6.908 8.858zM6.905 15.825 3 21.402h3.907l1.951-2.788zM16.585 2l-6.75 9.64 1.953 2.79L20.492 2zM17.292 7.965v13.437h3.2V3.395z"></path></svg>
|
||||
|
After Width: | Height: | Size: 356 B |
|
|
@ -0,0 +1,63 @@
|
|||
model: grok-beta
|
||||
label:
|
||||
en_US: Grok beta
|
||||
model_type: llm
|
||||
features:
|
||||
- multi-tool-call
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 131072
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
label:
|
||||
en_US: "Temperature"
|
||||
zh_Hans: "采样温度"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: top_p
|
||||
label:
|
||||
en_US: "Top P"
|
||||
zh_Hans: "Top P"
|
||||
type: float
|
||||
default: 0.7
|
||||
min: 0.0
|
||||
max: 1.0
|
||||
precision: 1
|
||||
required: true
|
||||
help:
|
||||
en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
|
||||
zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
|
||||
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
label:
|
||||
en_US: "Frequency Penalty"
|
||||
zh_Hans: "频率惩罚"
|
||||
type: float
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2.0
|
||||
precision: 1
|
||||
required: false
|
||||
help:
|
||||
en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
|
||||
zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
|
||||
|
||||
- name: user
|
||||
use_template: text
|
||||
label:
|
||||
en_US: "User"
|
||||
zh_Hans: "用户"
|
||||
type: string
|
||||
required: false
|
||||
help:
|
||||
en_US: "Used to track and differentiate conversation requests from different users."
|
||||
zh_Hans: "用于追踪和区分不同用户的对话请求。"
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
from collections.abc import Generator
|
||||
from typing import Optional, Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
|
||||
|
||||
|
||||
class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
|
||||
def _invoke(
|
||||
self,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
user: Optional[str] = None,
|
||||
) -> Union[LLMResult, Generator]:
|
||||
self._add_custom_parameters(credentials)
|
||||
return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream)
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
self._add_custom_parameters(credentials)
|
||||
super().validate_credentials(model, credentials)
|
||||
|
||||
@staticmethod
|
||||
def _add_custom_parameters(credentials) -> None:
|
||||
credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1"
|
||||
credentials["mode"] = LLMMode.CHAT.value
|
||||
credentials["function_calling_type"] = "tool_call"
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
import logging
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class XAIProvider(ModelProvider):
|
||||
def validate_provider_credentials(self, credentials: dict) -> None:
|
||||
"""
|
||||
Validate provider credentials
|
||||
if validate failed, raise exception
|
||||
|
||||
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
|
||||
"""
|
||||
try:
|
||||
model_instance = self.get_model_instance(ModelType.LLM)
|
||||
model_instance.validate_credentials(model="grok-beta", credentials=credentials)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
raise ex
|
||||
except Exception as ex:
|
||||
logger.exception(f"{self.get_provider_schema().provider} credentials validate failed")
|
||||
raise ex
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
provider: x
|
||||
label:
|
||||
en_US: xAI
|
||||
description:
|
||||
en_US: xAI is a company working on building artificial intelligence to accelerate human scientific discovery. We are guided by our mission to advance our collective understanding of the universe.
|
||||
icon_small:
|
||||
en_US: x-ai-logo.svg
|
||||
icon_large:
|
||||
en_US: x-ai-logo.svg
|
||||
help:
|
||||
title:
|
||||
en_US: Get your token from xAI
|
||||
zh_Hans: 从 xAI 获取 token
|
||||
url:
|
||||
en_US: https://x.ai/api
|
||||
supported_model_types:
|
||||
- llm
|
||||
configurate_methods:
|
||||
- predefined-model
|
||||
provider_credential_schema:
|
||||
credential_form_schemas:
|
||||
- variable: api_key
|
||||
label:
|
||||
en_US: API Key
|
||||
type: secret-input
|
||||
required: true
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Key
|
||||
en_US: Enter your API Key
|
||||
- variable: endpoint_url
|
||||
label:
|
||||
en_US: API Base
|
||||
type: text-input
|
||||
required: false
|
||||
default: https://api.x.ai/v1
|
||||
placeholder:
|
||||
zh_Hans: 在此输入您的 API Base
|
||||
en_US: Enter your API Base
|
||||
|
|
@ -14,6 +14,7 @@ import requests
|
|||
from docx import Document as DocxDocument
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from core.rag.extractor.extractor_base import BaseExtractor
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -86,7 +87,7 @@ class WordExtractor(BaseExtractor):
|
|||
image_count += 1
|
||||
if rel.is_external:
|
||||
url = rel.reltype
|
||||
response = requests.get(url, stream=True)
|
||||
response = ssrf_proxy.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
image_ext = mimetypes.guess_extension(response.headers["Content-Type"])
|
||||
file_uuid = str(uuid.uuid4())
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from hmac import new as hmac_new
|
|||
from json import loads as json_loads
|
||||
from threading import Lock
|
||||
from time import sleep, time
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from httpx import get, post
|
||||
from requests import get as requests_get
|
||||
|
|
@ -15,27 +15,27 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter,
|
|||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class AIPPTGenerateTool(BuiltinTool):
|
||||
class AIPPTGenerateToolAdapter:
|
||||
"""
|
||||
A tool for generating a ppt
|
||||
"""
|
||||
|
||||
_api_base_url = URL("https://co.aippt.cn/api")
|
||||
_api_token_cache = {}
|
||||
_api_token_cache_lock: Optional[Lock] = None
|
||||
_style_cache = {}
|
||||
_style_cache_lock: Optional[Lock] = None
|
||||
|
||||
_api_token_cache_lock = Lock()
|
||||
_style_cache_lock = Lock()
|
||||
|
||||
_task = {}
|
||||
_task_type_map = {
|
||||
"auto": 1,
|
||||
"markdown": 7,
|
||||
}
|
||||
_tool: BuiltinTool
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._api_token_cache_lock = Lock()
|
||||
self._style_cache_lock = Lock()
|
||||
def __init__(self, tool: BuiltinTool = None):
|
||||
self._tool = tool
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
|
|
@ -51,11 +51,11 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
"""
|
||||
title = tool_parameters.get("title", "")
|
||||
if not title:
|
||||
return self.create_text_message("Please provide a title for the ppt")
|
||||
return self._tool.create_text_message("Please provide a title for the ppt")
|
||||
|
||||
model = tool_parameters.get("model", "aippt")
|
||||
if not model:
|
||||
return self.create_text_message("Please provide a model for the ppt")
|
||||
return self._tool.create_text_message("Please provide a model for the ppt")
|
||||
|
||||
outline = tool_parameters.get("outline", "")
|
||||
|
||||
|
|
@ -68,8 +68,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
)
|
||||
|
||||
# get suit
|
||||
color = tool_parameters.get("color")
|
||||
style = tool_parameters.get("style")
|
||||
color: str = tool_parameters.get("color")
|
||||
style: str = tool_parameters.get("style")
|
||||
|
||||
if color == "__default__":
|
||||
color_id = ""
|
||||
|
|
@ -93,9 +93,9 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
# generate ppt
|
||||
_, ppt_url = self._generate_ppt(task_id=task_id, suit_id=suit_id, user_id=user_id)
|
||||
|
||||
return self.create_text_message(
|
||||
return self._tool.create_text_message(
|
||||
"""the ppt has been created successfully,"""
|
||||
f"""the ppt url is {ppt_url}"""
|
||||
f"""the ppt url is {ppt_url} ."""
|
||||
"""please give the ppt url to user and direct user to download it."""
|
||||
)
|
||||
|
||||
|
|
@ -111,8 +111,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
response = post(
|
||||
str(self._api_base_url / "ai" / "chat" / "v2" / "task"),
|
||||
|
|
@ -139,8 +139,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
|
@ -183,8 +183,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = requests_get(url=api_url, headers=headers, stream=True, timeout=(10, 60))
|
||||
|
|
@ -236,14 +236,15 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id=user_id),
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id=user_id),
|
||||
}
|
||||
|
||||
response = post(
|
||||
str(self._api_base_url / "design" / "v2" / "save"),
|
||||
headers=headers,
|
||||
data={"task_id": task_id, "template_id": suit_id},
|
||||
timeout=(10, 60),
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
|
|
@ -350,11 +351,13 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str:
|
||||
@staticmethod
|
||||
def _calculate_sign(access_key: str, secret_key: str, timestamp: int) -> str:
|
||||
return b64encode(
|
||||
hmac_new(
|
||||
key=secret_key.encode("utf-8"), msg=f"GET@/api/grant/token/@{timestamp}".encode(), digestmod=sha1
|
||||
key=secret_key.encode("utf-8"),
|
||||
msg=f"GET@/api/grant/token/@{timestamp}".encode(),
|
||||
digestmod=sha1,
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
|
||||
|
|
@ -419,10 +422,12 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
:param credentials: the credentials
|
||||
:return: Tuple[list[dict[id, color]], list[dict[id, style]]
|
||||
"""
|
||||
if not self.runtime.credentials.get("aippt_access_key") or not self.runtime.credentials.get("aippt_secret_key"):
|
||||
if not self._tool.runtime.credentials.get("aippt_access_key") or not self._tool.runtime.credentials.get(
|
||||
"aippt_secret_key"
|
||||
):
|
||||
raise Exception("Please provide aippt credentials")
|
||||
|
||||
return self._get_styles(credentials=self.runtime.credentials, user_id=user_id)
|
||||
return self._get_styles(credentials=self._tool.runtime.credentials, user_id=user_id)
|
||||
|
||||
def _get_suit(self, style_id: int, colour_id: int) -> int:
|
||||
"""
|
||||
|
|
@ -430,8 +435,8 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
"""
|
||||
headers = {
|
||||
"x-channel": "",
|
||||
"x-api-key": self.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self.runtime.credentials, user_id="__dify_system__"),
|
||||
"x-api-key": self._tool.runtime.credentials["aippt_access_key"],
|
||||
"x-token": self._get_api_token(credentials=self._tool.runtime.credentials, user_id="__dify_system__"),
|
||||
}
|
||||
response = get(
|
||||
str(self._api_base_url / "template_component" / "suit" / "search"),
|
||||
|
|
@ -496,3 +501,18 @@ class AIPPTGenerateTool(BuiltinTool):
|
|||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class AIPPTGenerateTool(BuiltinTool):
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
return AIPPTGenerateToolAdapter(self)._invoke(user_id, tool_parameters)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
return AIPPTGenerateToolAdapter(self).get_runtime_parameters()
|
||||
|
||||
@classmethod
|
||||
def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str:
|
||||
return AIPPTGenerateToolAdapter()._get_api_token(credentials, user_id)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class NodeRunMetadataKey(str, Enum):
|
|||
PARALLEL_START_NODE_ID = "parallel_start_node_id"
|
||||
PARENT_PARALLEL_ID = "parent_parallel_id"
|
||||
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
|
||||
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ class BaseNodeEvent(GraphEngineEvent):
|
|||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
predecessor_node_id: Optional[str] = None
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""predecessor node id"""
|
||||
|
||||
|
||||
|
|
@ -81,6 +82,10 @@ class NodeRunFailedEvent(BaseNodeEvent):
|
|||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
class NodeInIterationFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
|
@ -129,6 +134,8 @@ class BaseIterationEvent(GraphEngineEvent):
|
|||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
parallel_mode_run_id: Optional[str] = None
|
||||
"""iteratoin run in parallel mode run id"""
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import time
|
|||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from copy import copy, deepcopy
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
|
@ -724,6 +725,16 @@ class GraphEngine:
|
|||
"""
|
||||
return time.perf_counter() - start_at > max_execution_time
|
||||
|
||||
def create_copy(self):
|
||||
"""
|
||||
create a graph engine copy
|
||||
:return: with a new variable pool instance of graph engine
|
||||
"""
|
||||
new_instance = copy(self)
|
||||
new_instance.graph_runtime_state = copy(self.graph_runtime_state)
|
||||
new_instance.graph_runtime_state.variable_pool = deepcopy(self.graph_runtime_state.variable_pool)
|
||||
return new_instance
|
||||
|
||||
|
||||
class GraphRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,12 @@ from core.workflow.nodes.code.entities import CodeNodeData
|
|||
from core.workflow.nodes.enums import NodeType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .exc import (
|
||||
CodeNodeError,
|
||||
DepthLimitError,
|
||||
OutputValidationError,
|
||||
)
|
||||
|
||||
|
||||
class CodeNode(BaseNode[CodeNodeData]):
|
||||
_node_data_cls = CodeNodeData
|
||||
|
|
@ -60,7 +66,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
|
||||
# Transform result
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
except (CodeExecutionError, ValueError) as e:
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e))
|
||||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
|
@ -76,10 +82,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Output variable `{variable}` must be a string")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{variable}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_LENGTH} characters"
|
||||
)
|
||||
|
|
@ -97,10 +103,10 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise ValueError(f"Output variable `{variable}` must be a number")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` is out of range,"
|
||||
f" it must be between {dify_config.CODE_MIN_NUMBER} and {dify_config.CODE_MAX_NUMBER}."
|
||||
)
|
||||
|
|
@ -108,7 +114,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(value, float):
|
||||
# raise error if precision is too high
|
||||
if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output variable `{variable}` has too high precision,"
|
||||
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
|
||||
)
|
||||
|
|
@ -125,7 +131,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
:return:
|
||||
"""
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise ValueError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
transformed_result = {}
|
||||
if output_schema is None:
|
||||
|
|
@ -177,14 +183,14 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
depth=depth + 1,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}.{output_name} is not a valid array."
|
||||
f" make sure all elements are of the same type."
|
||||
)
|
||||
elif output_value is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.")
|
||||
|
||||
return result
|
||||
|
||||
|
|
@ -192,7 +198,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
for output_name, output_config in output_schema.items():
|
||||
dot = "." if prefix else ""
|
||||
if output_name not in result:
|
||||
raise ValueError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.")
|
||||
|
||||
if output_config.type == "object":
|
||||
# check if output is object
|
||||
|
|
@ -200,7 +206,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result.get(output_name), type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an object,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
|
|
@ -228,13 +234,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
|
@ -249,13 +255,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_STRING_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_STRING_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
|
@ -270,13 +276,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if isinstance(result[output_name], type(None)):
|
||||
transformed_result[output_name] = None
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name} is not an array,"
|
||||
f" got {type(result.get(output_name))} instead."
|
||||
)
|
||||
else:
|
||||
if len(result[output_name]) > dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"The length of output variable `{prefix}{dot}{output_name}` must be"
|
||||
f" less than {dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH} elements."
|
||||
)
|
||||
|
|
@ -286,7 +292,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
if value is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(
|
||||
raise OutputValidationError(
|
||||
f"Output {prefix}{dot}{output_name}[{i}] is not an object,"
|
||||
f" got {type(value)} instead at index {i}."
|
||||
)
|
||||
|
|
@ -303,13 +309,13 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
for i, value in enumerate(result[output_name])
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"Output type {output_config.type} is not supported.")
|
||||
raise OutputValidationError(f"Output type {output_config.type} is not supported.")
|
||||
|
||||
parameters_validated[output_name] = True
|
||||
|
||||
# check if all output parameters are validated
|
||||
if len(parameters_validated) != len(result):
|
||||
raise ValueError("Not all output parameters are validated.")
|
||||
raise CodeNodeError("Not all output parameters are validated.")
|
||||
|
||||
return transformed_result
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
class CodeNodeError(ValueError):
|
||||
"""Base class for code node errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OutputValidationError(CodeNodeError):
|
||||
"""Raised when there is an output validation error."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DepthLimitError(CodeNodeError):
|
||||
"""Raised when the depth limit is reached."""
|
||||
|
||||
pass
|
||||
|
|
@ -198,10 +198,8 @@ def _download_file_content(file: File) -> bytes:
|
|||
response = ssrf_proxy.get(file.remote_url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
elif file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
return file_manager.download(file)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {file.transfer_method}")
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
class HttpRequestNodeError(ValueError):
|
||||
"""Custom error for HTTP request node."""
|
||||
|
||||
|
||||
class AuthorizationConfigError(HttpRequestNodeError):
|
||||
"""Raised when authorization config is missing or invalid."""
|
||||
|
||||
|
||||
class FileFetchError(HttpRequestNodeError):
|
||||
"""Raised when a file cannot be fetched."""
|
||||
|
||||
|
||||
class InvalidHttpMethodError(HttpRequestNodeError):
|
||||
"""Raised when an invalid HTTP method is used."""
|
||||
|
||||
|
||||
class ResponseSizeError(HttpRequestNodeError):
|
||||
"""Raised when the response size exceeds the allowed threshold."""
|
||||
|
|
@ -18,6 +18,12 @@ from .entities import (
|
|||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import (
|
||||
AuthorizationConfigError,
|
||||
FileFetchError,
|
||||
InvalidHttpMethodError,
|
||||
ResponseSizeError,
|
||||
)
|
||||
|
||||
BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
"json": "application/json",
|
||||
|
|
@ -51,7 +57,7 @@ class Executor:
|
|||
# If authorization API key is present, convert the API key using the variable pool
|
||||
if node_data.authorization.type == "api-key":
|
||||
if node_data.authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
node_data.authorization.config.api_key = variable_pool.convert_template(
|
||||
node_data.authorization.config.api_key
|
||||
).text
|
||||
|
|
@ -82,8 +88,10 @@ class Executor:
|
|||
self.url = self.variable_pool.convert_template(self.node_data.url).text
|
||||
|
||||
def _init_params(self):
|
||||
params = self.variable_pool.convert_template(self.node_data.params).text
|
||||
self.params = _plain_text_to_dict(params)
|
||||
params = _plain_text_to_dict(self.node_data.params)
|
||||
for key in params:
|
||||
params[key] = self.variable_pool.convert_template(params[key]).text
|
||||
self.params = params
|
||||
|
||||
def _init_headers(self):
|
||||
headers = self.variable_pool.convert_template(self.node_data.headers).text
|
||||
|
|
@ -116,7 +124,7 @@ class Executor:
|
|||
file_selector = data[0].file
|
||||
file_variable = self.variable_pool.get_file(file_selector)
|
||||
if file_variable is None:
|
||||
raise ValueError(f"cannot fetch file with selector {file_selector}")
|
||||
raise FileFetchError(f"cannot fetch file with selector {file_selector}")
|
||||
file = file_variable.value
|
||||
self.content = file_manager.download(file)
|
||||
case "x-www-form-urlencoded":
|
||||
|
|
@ -155,12 +163,12 @@ class Executor:
|
|||
headers = deepcopy(self.headers) or {}
|
||||
if self.auth.type == "api-key":
|
||||
if self.auth.config is None:
|
||||
raise ValueError("self.authorization config is required")
|
||||
raise AuthorizationConfigError("self.authorization config is required")
|
||||
if authorization.config is None:
|
||||
raise ValueError("authorization config is required")
|
||||
raise AuthorizationConfigError("authorization config is required")
|
||||
|
||||
if self.auth.config.api_key is None:
|
||||
raise ValueError("api_key is required")
|
||||
raise AuthorizationConfigError("api_key is required")
|
||||
|
||||
if not authorization.config.header:
|
||||
authorization.config.header = "Authorization"
|
||||
|
|
@ -183,7 +191,7 @@ class Executor:
|
|||
else dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE
|
||||
)
|
||||
if executor_response.size > threshold_size:
|
||||
raise ValueError(
|
||||
raise ResponseSizeError(
|
||||
f'{"File" if executor_response.is_file else "Text"} size is too large,'
|
||||
f' max size is {threshold_size / 1024 / 1024:.2f} MB,'
|
||||
f' but current size is {executor_response.readable_size}.'
|
||||
|
|
@ -196,7 +204,7 @@ class Executor:
|
|||
do http request depending on api bundle
|
||||
"""
|
||||
if self.method not in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
raise InvalidHttpMethodError(f"Invalid http method {self.method}")
|
||||
|
||||
request_args = {
|
||||
"url": self.url,
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .entities import (
|
|||
HttpRequestNodeTimeout,
|
||||
Response,
|
||||
)
|
||||
from .exc import HttpRequestNodeError
|
||||
|
||||
HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
|
||||
connect=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT,
|
||||
|
|
@ -77,7 +78,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
|
|||
"request": http_executor.to_log(),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
except HttpRequestNodeError as e:
|
||||
logger.warning(f"http request node {self.node_id} failed to run: {e}")
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
|
@ -5,6 +6,12 @@ from pydantic import Field
|
|||
from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData
|
||||
|
||||
|
||||
class ErrorHandleMode(str, Enum):
|
||||
TERMINATED = "terminated"
|
||||
CONTINUE_ON_ERROR = "continue-on-error"
|
||||
REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output"
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
|
|
@ -13,6 +20,9 @@ class IterationNodeData(BaseIterationNodeData):
|
|||
parent_loop_id: Optional[str] = None # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
is_parallel: bool = False # open the parallel mode or not
|
||||
parallel_nums: int = 10 # the numbers of parallel
|
||||
error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error
|
||||
|
||||
|
||||
class IterationStartNodeData(BaseNodeData):
|
||||
|
|
|
|||
|
|
@ -1,12 +1,20 @@
|
|||
import logging
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, wait
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
from queue import Empty, Queue
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.variables import IntegerSegment
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeRunResult,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
|
|
@ -17,6 +25,9 @@ from core.workflow.graph_engine.entities.event import (
|
|||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeInIterationFailedEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
|
|
@ -24,9 +35,11 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -38,6 +51,17 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
return {
|
||||
"type": "iteration",
|
||||
"config": {
|
||||
"is_parallel": False,
|
||||
"parallel_nums": 10,
|
||||
"error_handle_mode": ErrorHandleMode.TERMINATED.value,
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
|
|
@ -83,7 +107,7 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
|
|
@ -123,108 +147,64 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
index=0,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
|
||||
outputs: list[Any] = []
|
||||
try:
|
||||
for _ in range(len(iterator_list_value)):
|
||||
# run workflow
|
||||
rst = graph_engine.run()
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.ITERATION_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
if self.node_data.is_parallel:
|
||||
futures: list[Future] = []
|
||||
q = Queue()
|
||||
thread_pool = GraphEngineThreadPool(max_workers=self.node_data.parallel_nums, max_submit_count=100)
|
||||
for index, item in enumerate(iterator_list_value):
|
||||
future: Future = thread_pool.submit(
|
||||
self._run_single_iter_parallel,
|
||||
current_app._get_current_object(),
|
||||
q,
|
||||
iterator_list_value,
|
||||
inputs,
|
||||
outputs,
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
index,
|
||||
item,
|
||||
)
|
||||
future.add_done_callback(thread_pool.task_done_callback)
|
||||
futures.append(future)
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
if isinstance(event, IterationRunNextEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(futures):
|
||||
q.put(None)
|
||||
yield event
|
||||
if isinstance(event, RunCompletedEvent):
|
||||
q.put(None)
|
||||
for f in futures:
|
||||
if not f.done():
|
||||
f.cancel()
|
||||
yield event
|
||||
if isinstance(event, IterationRunFailedEvent):
|
||||
q.put(None)
|
||||
yield event
|
||||
except Empty:
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(index_variable, IntegerSegment):
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Invalid index variable type: {type(index_variable)}",
|
||||
)
|
||||
)
|
||||
return
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = index_variable.value
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
|
||||
yield event
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
yield event
|
||||
|
||||
# append to iteration output variable list
|
||||
current_iteration_output_variable = variable_pool.get(self.node_data.output_selector)
|
||||
if current_iteration_output_variable is None:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Iteration output variable {self.node_data.output_selector} not found",
|
||||
)
|
||||
# wait all threads
|
||||
wait(futures)
|
||||
else:
|
||||
for _ in range(len(iterator_list_value)):
|
||||
yield from self._run_single_iter(
|
||||
iterator_list_value,
|
||||
variable_pool,
|
||||
inputs,
|
||||
outputs,
|
||||
start_at,
|
||||
graph_engine,
|
||||
iteration_graph,
|
||||
)
|
||||
return
|
||||
current_iteration_output = current_iteration_output_variable.to_object()
|
||||
outputs.append(current_iteration_output)
|
||||
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# move to next iteration
|
||||
current_index_variable = variable_pool.get([self.node_id, "index"])
|
||||
if not isinstance(current_index_variable, IntegerSegment):
|
||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||
|
||||
next_index = current_index_variable.value + 1
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output),
|
||||
)
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
|
|
@ -330,3 +310,231 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _handle_event_metadata(
|
||||
self, event: BaseNodeEvent, iter_run_index: str, parallel_mode_run_id: str
|
||||
) -> NodeRunStartedEvent | BaseNodeEvent:
|
||||
"""
|
||||
add iteration metadata to event.
|
||||
"""
|
||||
if not isinstance(event, BaseNodeEvent):
|
||||
return event
|
||||
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
|
||||
event.parallel_mode_run_id = parallel_mode_run_id
|
||||
return event
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
if self.node_data.is_parallel:
|
||||
metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
|
||||
else:
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = iter_run_index
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
return event
|
||||
|
||||
def _run_single_iter(
|
||||
self,
|
||||
iterator_list_value: list[str],
|
||||
variable_pool: VariablePool,
|
||||
inputs: dict[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
parallel_mode_run_id: Optional[str] = None,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration
|
||||
"""
|
||||
try:
|
||||
rst = graph_engine.run()
|
||||
# get current iteration index
|
||||
current_index = variable_pool.get([self.node_id, "index"]).value
|
||||
next_index = int(current_index) + 1
|
||||
|
||||
if current_index is None:
|
||||
raise ValueError(f"iteration {self.node_id} current index not found")
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if (
|
||||
isinstance(event, BaseNodeEvent)
|
||||
and event.node_type == NodeType.ITERATION_START
|
||||
and not isinstance(event, NodeRunStreamChunkEvent)
|
||||
):
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
yield self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
if self.node_data.is_parallel:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
else:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": jsonable_encoder(outputs)},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
return
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
metadata_event = self._handle_event_metadata(event, current_index, parallel_mode_run_id)
|
||||
if isinstance(event, NodeRunFailedEvent):
|
||||
if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
outputs.insert(current_index, None)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
|
||||
yield NodeInIterationFailedEvent(
|
||||
**metadata_event.model_dump(),
|
||||
)
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=None,
|
||||
)
|
||||
return
|
||||
elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": None},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=event.error,
|
||||
)
|
||||
yield metadata_event
|
||||
|
||||
current_iteration_output = variable_pool.get(self.node_data.output_selector).value
|
||||
outputs.insert(current_index, current_iteration_output)
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove([node_id])
|
||||
|
||||
# move to next iteration
|
||||
variable_pool.add([self.node_id, "index"], next_index)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add([self.node_id, "item"], iterator_list_value[next_index])
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
pre_iteration_output=jsonable_encoder(current_iteration_output) if current_iteration_output else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Iteration run failed:{str(e)}")
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={"output": None},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
|
||||
error=str(e),
|
||||
)
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
def _run_single_iter_parallel(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
q: Queue,
|
||||
iterator_list_value: list[str],
|
||||
inputs: dict[str, list],
|
||||
outputs: list,
|
||||
start_at: datetime,
|
||||
graph_engine: "GraphEngine",
|
||||
iteration_graph: Graph,
|
||||
index: int,
|
||||
item: Any,
|
||||
) -> Generator[NodeEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
run single iteration in parallel mode
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
parallel_mode_run_id = uuid.uuid4().hex
|
||||
graph_engine_copy = graph_engine.create_copy()
|
||||
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
|
||||
variable_pool_copy.add([self.node_id, "index"], index)
|
||||
variable_pool_copy.add([self.node_id, "item"], item)
|
||||
for event in self._run_single_iter(
|
||||
iterator_list_value=iterator_list_value,
|
||||
variable_pool=variable_pool_copy,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
start_at=start_at,
|
||||
graph_engine=graph_engine_copy,
|
||||
iteration_graph=iteration_graph,
|
||||
parallel_mode_run_id=parallel_mode_run_id,
|
||||
):
|
||||
q.put(event)
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]:
|
|||
return lambda x: x.type
|
||||
case "extension":
|
||||
return lambda x: x.extension or ""
|
||||
case "mimetype":
|
||||
case "mime_type":
|
||||
return lambda x: x.mime_type or ""
|
||||
case "transfer_method":
|
||||
return lambda x: x.transfer_method
|
||||
|
|
@ -295,4 +295,4 @@ def _order_file(*, order: Literal["asc", "desc"], order_by: str = "", array: Seq
|
|||
extract_func = _get_file_extract_number_func(key=order_by)
|
||||
return sorted(array, key=lambda x: extract_func(x), reverse=order == "desc")
|
||||
else:
|
||||
raise ValueError(f"Invalid order key: {order_by}")
|
||||
raise InvalidKeyError(f"Invalid order key: {order_by}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,26 @@
|
|||
class LLMNodeError(ValueError):
|
||||
"""Base class for LLM Node errors."""
|
||||
|
||||
|
||||
class VariableNotFoundError(LLMNodeError):
|
||||
"""Raised when a required variable is not found."""
|
||||
|
||||
|
||||
class InvalidContextStructureError(LLMNodeError):
|
||||
"""Raised when the context structure is invalid."""
|
||||
|
||||
|
||||
class InvalidVariableTypeError(LLMNodeError):
|
||||
"""Raised when the variable type is invalid."""
|
||||
|
||||
|
||||
class ModelNotExistError(LLMNodeError):
|
||||
"""Raised when the specified model does not exist."""
|
||||
|
||||
|
||||
class LLMModeRequiredError(LLMNodeError):
|
||||
"""Raised when LLM mode is required but not provided."""
|
||||
|
||||
|
||||
class NoPromptFoundError(LLMNodeError):
|
||||
"""Raised when no prompt is found in the LLM configuration."""
|
||||
|
|
@ -56,6 +56,15 @@ from .entities import (
|
|||
LLMNodeData,
|
||||
ModelConfig,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
InvalidVariableTypeError,
|
||||
LLMModeRequiredError,
|
||||
LLMNodeError,
|
||||
ModelNotExistError,
|
||||
NoPromptFoundError,
|
||||
VariableNotFoundError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.file.models import File
|
||||
|
|
@ -103,7 +112,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
yield event
|
||||
|
||||
if context:
|
||||
node_inputs["#context#"] = context # type: ignore
|
||||
node_inputs["#context#"] = context
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = self._fetch_model_config(self.node_data.model)
|
||||
|
|
@ -115,7 +124,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
if self.node_data.memory:
|
||||
query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
|
||||
if not query:
|
||||
raise ValueError("Query not found")
|
||||
raise VariableNotFoundError("Query not found")
|
||||
query = query.text
|
||||
else:
|
||||
query = None
|
||||
|
|
@ -161,7 +170,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
break
|
||||
except Exception as e:
|
||||
except LLMNodeError as e:
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
|
|
@ -275,7 +284,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
|
|
@ -325,7 +334,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
for variable_selector in variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
||||
if isinstance(variable, NoneSegment):
|
||||
inputs[variable_selector.variable] = ""
|
||||
inputs[variable_selector.variable] = variable.to_object()
|
||||
|
|
@ -338,7 +347,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
for variable_selector in query_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
raise VariableNotFoundError(f"Variable {variable_selector.variable} not found")
|
||||
if isinstance(variable, NoneSegment):
|
||||
continue
|
||||
inputs[variable_selector.variable] = variable.to_object()
|
||||
|
|
@ -355,7 +364,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
return variable.value
|
||||
elif isinstance(variable, NoneSegment | ArrayAnySegment):
|
||||
return []
|
||||
raise ValueError(f"Invalid variable type: {type(variable)}")
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
def _fetch_context(self, node_data: LLMNodeData):
|
||||
if not node_data.context.enabled:
|
||||
|
|
@ -376,7 +385,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
context_str += item + "\n"
|
||||
else:
|
||||
if "content" not in item:
|
||||
raise ValueError(f"Invalid context structure: {item}")
|
||||
raise InvalidContextStructureError(f"Invalid context structure: {item}")
|
||||
|
||||
context_str += item["content"] + "\n"
|
||||
|
||||
|
|
@ -441,7 +450,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
|
@ -460,12 +469,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
# get model mode
|
||||
model_mode = node_data_model.mode
|
||||
if not model_mode:
|
||||
raise ValueError("LLM mode is required.")
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model {model_name} not exist.")
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=provider_name,
|
||||
|
|
@ -564,7 +573,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
filtered_prompt_messages.append(prompt_message)
|
||||
|
||||
if not filtered_prompt_messages:
|
||||
raise ValueError(
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
|
@ -636,7 +645,7 @@ class LLMNode(BaseNode[LLMNodeData]):
|
|||
variable_template_parser = VariableTemplateParser(template=prompt_template.text)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
else:
|
||||
raise ValueError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}")
|
||||
|
||||
variable_mapping = {}
|
||||
for variable_selector in variable_selectors:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,50 @@
|
|||
class ParameterExtractorNodeError(ValueError):
|
||||
"""Base error for ParameterExtractorNode."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class ModelSchemaNotFoundError(ParameterExtractorNodeError):
|
||||
"""Raised when the model schema is not found."""
|
||||
|
||||
|
||||
class InvalidInvokeResultError(ParameterExtractorNodeError):
|
||||
"""Raised when the invoke result is invalid."""
|
||||
|
||||
|
||||
class InvalidTextContentTypeError(ParameterExtractorNodeError):
|
||||
"""Raised when the text content type is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberOfParametersError(ParameterExtractorNodeError):
|
||||
"""Raised when the number of parameters is invalid."""
|
||||
|
||||
|
||||
class RequiredParameterMissingError(ParameterExtractorNodeError):
|
||||
"""Raised when a required parameter is missing."""
|
||||
|
||||
|
||||
class InvalidSelectValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a select value is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a number value is invalid."""
|
||||
|
||||
|
||||
class InvalidBoolValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a bool value is invalid."""
|
||||
|
||||
|
||||
class InvalidStringValueError(ParameterExtractorNodeError):
|
||||
"""Raised when a string value is invalid."""
|
||||
|
||||
|
||||
class InvalidArrayValueError(ParameterExtractorNodeError):
|
||||
"""Raised when an array value is invalid."""
|
||||
|
||||
|
||||
class InvalidModelModeError(ParameterExtractorNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
|
|
@ -32,6 +32,21 @@ from extensions.ext_database import db
|
|||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import ParameterExtractorNodeData
|
||||
from .exc import (
|
||||
InvalidArrayValueError,
|
||||
InvalidBoolValueError,
|
||||
InvalidInvokeResultError,
|
||||
InvalidModelModeError,
|
||||
InvalidModelTypeError,
|
||||
InvalidNumberOfParametersError,
|
||||
InvalidNumberValueError,
|
||||
InvalidSelectValueError,
|
||||
InvalidStringValueError,
|
||||
InvalidTextContentTypeError,
|
||||
ModelSchemaNotFoundError,
|
||||
ParameterExtractorNodeError,
|
||||
RequiredParameterMissingError,
|
||||
)
|
||||
from .prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
|
|
@ -85,7 +100,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(
|
||||
|
|
@ -93,7 +108,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
credentials=model_config.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
raise ModelSchemaNotFoundError("Model schema not found")
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(
|
||||
|
|
@ -155,7 +170,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
process_data["usage"] = jsonable_encoder(usage)
|
||||
process_data["tool_call"] = jsonable_encoder(tool_call)
|
||||
process_data["llm_text"] = text
|
||||
except Exception as e:
|
||||
except ParameterExtractorNodeError as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
|
|
@ -177,7 +192,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
try:
|
||||
result = self._validate_result(data=node_data, result=result or {})
|
||||
except Exception as e:
|
||||
except ParameterExtractorNodeError as e:
|
||||
error = str(e)
|
||||
|
||||
# transform result into standard format
|
||||
|
|
@ -217,11 +232,11 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
# handle invoke result
|
||||
if not isinstance(invoke_result, LLMResult):
|
||||
raise ValueError(f"Invalid invoke result: {invoke_result}")
|
||||
raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content
|
||||
if not isinstance(text, str):
|
||||
raise ValueError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.")
|
||||
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
|
@ -344,7 +359,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
files=files,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model mode: {model_mode}")
|
||||
raise InvalidModelModeError(f"Invalid model mode: {model_mode}")
|
||||
|
||||
def _generate_prompt_engineering_completion_prompt(
|
||||
self,
|
||||
|
|
@ -449,36 +464,36 @@ class ParameterExtractorNode(LLMNode):
|
|||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise ValueError("Invalid number of parameters")
|
||||
raise InvalidNumberOfParametersError("Invalid number of parameters")
|
||||
|
||||
for parameter in data.parameters:
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise ValueError(f"Parameter {parameter.name} is required")
|
||||
raise RequiredParameterMissingError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == "select" and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "number" and not isinstance(result.get(parameter.name), int | float):
|
||||
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
raise InvalidNumberValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "bool" and not isinstance(result.get(parameter.name), bool):
|
||||
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
raise InvalidBoolValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == "string" and not isinstance(result.get(parameter.name), str):
|
||||
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
raise InvalidStringValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith("array"):
|
||||
parameters = result.get(parameter.name)
|
||||
if not isinstance(parameters, list):
|
||||
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in parameters:
|
||||
if nested_type == "number" and not isinstance(item, int | float):
|
||||
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == "string" and not isinstance(item, str):
|
||||
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == "object" and not isinstance(item, dict):
|
||||
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
raise InvalidArrayValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
|
|
@ -634,7 +649,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _get_prompt_engineering_prompt_template(
|
||||
self,
|
||||
|
|
@ -669,7 +684,7 @@ class ParameterExtractorNode(LLMNode):
|
|||
.replace("}γγγ", "")
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
raise InvalidModelModeError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _calculate_rest_token(
|
||||
self,
|
||||
|
|
@ -683,12 +698,12 @@ class ParameterExtractorNode(LLMNode):
|
|||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
raise ModelSchemaNotFoundError("Model schema not found")
|
||||
|
||||
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ class ToolNode(BaseNode[ToolNodeData]):
|
|||
)
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
tool_parameters = tool_runtime.parameters or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ upload_config_fields = {
|
|||
"image_file_size_limit": fields.Integer,
|
||||
"video_file_size_limit": fields.Integer,
|
||||
"audio_file_size_limit": fields.Integer,
|
||||
"workflow_file_upload_limit": fields.Integer,
|
||||
}
|
||||
|
||||
file_fields = {
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ from .model import (
|
|||
AppMode,
|
||||
Conversation,
|
||||
EndUser,
|
||||
FileUploadConfig,
|
||||
InstalledApp,
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
|
|
@ -50,6 +49,5 @@ __all__ = [
|
|||
"Tenant",
|
||||
"Conversation",
|
||||
"MessageAnnotation",
|
||||
"FileUploadConfig",
|
||||
"ToolFile",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, Optional
|
||||
|
|
@ -9,7 +9,6 @@ from typing import Any, Literal, Optional
|
|||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_login import UserMixin
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import Float, func, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
|
|
@ -25,14 +24,6 @@ from .account import Account, Tenant
|
|||
from .types import StringUUID
|
||||
|
||||
|
||||
class FileUploadConfig(BaseModel):
|
||||
enabled: bool = Field(default=False)
|
||||
allowed_file_types: Sequence[FileType] = Field(default_factory=list)
|
||||
allowed_extensions: Sequence[str] = Field(default_factory=list)
|
||||
allowed_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list)
|
||||
number_limits: int = Field(default=0, gt=0, le=10)
|
||||
|
||||
|
||||
class DifySetup(db.Model):
|
||||
__tablename__ = "dify_setups"
|
||||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
|
@ -115,7 +106,7 @@ class App(db.Model):
|
|||
return site
|
||||
|
||||
@property
|
||||
def app_model_config(self) -> Optional["AppModelConfig"]:
|
||||
def app_model_config(self):
|
||||
if self.app_model_config_id:
|
||||
return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
|
||||
|
||||
|
|
@ -1307,7 +1298,7 @@ class Site(db.Model):
|
|||
privacy_policy = db.Column(db.String(255))
|
||||
show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
|
||||
customize_domain = db.Column(db.String(255))
|
||||
customize_token_strategy = db.Column(db.String(255), nullable=False)
|
||||
prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
|
@ -1318,6 +1309,16 @@ class Site(db.Model):
|
|||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
code = db.Column(db.String(255))
|
||||
|
||||
@property
|
||||
def custom_disclaimer(self):
|
||||
return self._custom_disclaimer
|
||||
|
||||
@custom_disclaimer.setter
|
||||
def custom_disclaimer(self, value: str):
|
||||
if len(value) > 512:
|
||||
raise ValueError("Custom disclaimer cannot exceed 512 characters.")
|
||||
self._custom_disclaimer = value
|
||||
|
||||
@staticmethod
|
||||
def generate_code(n):
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -2532,6 +2532,19 @@ files = [
|
|||
{file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fire"
|
||||
version = "0.7.0"
|
||||
description = "A library for automatically generating command line interfaces."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "fire-0.7.0.tar.gz", hash = "sha256:961550f07936eaf65ad1dc8360f2b2bf8408fad46abbfa4d2a3794f8d2a95cdf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
termcolor = "*"
|
||||
|
||||
[[package]]
|
||||
name = "flasgger"
|
||||
version = "0.9.7.1"
|
||||
|
|
@ -2697,6 +2710,19 @@ files = [
|
|||
{file = "flatbuffers-24.3.25.tar.gz", hash = "sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fontmeta"
|
||||
version = "1.6.1"
|
||||
description = "An Utility to get ttf/otf font metadata"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "fontmeta-1.6.1.tar.gz", hash = "sha256:837e5bc4da879394b41bda1428a8a480eb7c4e993799a93cfb582bab771a9c24"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fonttools = "*"
|
||||
|
||||
[[package]]
|
||||
name = "fonttools"
|
||||
version = "4.54.1"
|
||||
|
|
@ -5279,6 +5305,22 @@ files = [
|
|||
{file = "monotonic-1.6.tar.gz", hash = "sha256:3a55207bcfed53ddd5c5bae174524062935efed17792e9de2ad0205ce9ad63f7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mplfonts"
|
||||
version = "0.0.8"
|
||||
description = "Fonts manager for matplotlib"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "mplfonts-0.0.8-py3-none-any.whl", hash = "sha256:b2182e5b0baa216cf016dec19942740e5b48956415708ad2d465e03952112ec1"},
|
||||
{file = "mplfonts-0.0.8.tar.gz", hash = "sha256:0abcb2fc0605645e1e7561c6923014d856f11676899b33b4d89757843f5e7c22"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
fire = ">=0.4.0"
|
||||
fontmeta = ">=1.6.1"
|
||||
matplotlib = ">=3.4"
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
|
@ -9300,6 +9342,20 @@ files = [
|
|||
[package.dependencies]
|
||||
tencentcloud-sdk-python-common = "3.0.1257"
|
||||
|
||||
[[package]]
|
||||
name = "termcolor"
|
||||
version = "2.5.0"
|
||||
description = "ANSI color formatting for output in terminal"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "termcolor-2.5.0-py3-none-any.whl", hash = "sha256:37b17b5fc1e604945c2642c872a3764b5d547a48009871aea3edd3afa180afb8"},
|
||||
{file = "termcolor-2.5.0.tar.gz", hash = "sha256:998d8d27da6d48442e8e1f016119076b690d962507531df4890fcd2db2ef8a6f"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
tests = ["pytest", "pytest-cov"]
|
||||
|
||||
[[package]]
|
||||
name = "threadpoolctl"
|
||||
version = "3.5.0"
|
||||
|
|
@ -10046,13 +10102,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "vanna"
|
||||
version = "0.7.3"
|
||||
version = "0.7.5"
|
||||
description = "Generate SQL queries from natural language"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
files = [
|
||||
{file = "vanna-0.7.3-py3-none-any.whl", hash = "sha256:82ba39e5d6c503d1c8cca60835ed401d20ec3a3da98d487f529901dcb30061d6"},
|
||||
{file = "vanna-0.7.3.tar.gz", hash = "sha256:4590dd94d2fe180b4efc7a83c867b73144ef58794018910dc226857cfb703077"},
|
||||
{file = "vanna-0.7.5-py3-none-any.whl", hash = "sha256:07458c7befa49de517a8760c2d80a13147278b484c515d49a906acc88edcb835"},
|
||||
{file = "vanna-0.7.5.tar.gz", hash = "sha256:2fdffc58832898e4fc8e93c45b173424db59a22773b22ca348640161d391eacf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
|
|
@ -10073,7 +10129,7 @@ sqlparse = "*"
|
|||
tabulate = "*"
|
||||
|
||||
[package.extras]
|
||||
all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "chromadb", "db-dtypes", "duckdb", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "zhipuai"]
|
||||
all = ["PyMySQL", "anthropic", "azure-common", "azure-identity", "azure-search-documents", "boto", "boto3", "botocore", "chromadb", "db-dtypes", "duckdb", "faiss-cpu", "fastembed", "google-cloud-aiplatform", "google-cloud-bigquery", "google-generativeai", "httpx", "langchain_core", "langchain_postgres", "marqo", "mistralai (>=1.0.0)", "ollama", "openai", "opensearch-dsl", "opensearch-py", "pinecone-client", "psycopg2-binary", "pymilvus[model]", "qdrant-client", "qianfan", "snowflake-connector-python", "transformers", "weaviate-client", "xinference-client", "zhipuai"]
|
||||
anthropic = ["anthropic"]
|
||||
azuresearch = ["azure-common", "azure-identity", "azure-search-documents", "fastembed"]
|
||||
bedrock = ["boto3", "botocore"]
|
||||
|
|
@ -10081,6 +10137,8 @@ bigquery = ["google-cloud-bigquery"]
|
|||
chromadb = ["chromadb"]
|
||||
clickhouse = ["clickhouse_connect"]
|
||||
duckdb = ["duckdb"]
|
||||
faiss-cpu = ["faiss-cpu"]
|
||||
faiss-gpu = ["faiss-gpu"]
|
||||
gemini = ["google-generativeai"]
|
||||
google = ["google-cloud-aiplatform", "google-generativeai"]
|
||||
hf = ["transformers"]
|
||||
|
|
@ -10091,6 +10149,7 @@ mysql = ["PyMySQL"]
|
|||
ollama = ["httpx", "ollama"]
|
||||
openai = ["openai"]
|
||||
opensearch = ["opensearch-dsl", "opensearch-py"]
|
||||
pgvector = ["langchain-postgres (>=0.0.12)"]
|
||||
pinecone = ["fastembed", "pinecone-client"]
|
||||
postgres = ["db-dtypes", "psycopg2-binary"]
|
||||
qdrant = ["fastembed", "qdrant-client"]
|
||||
|
|
@ -10099,6 +10158,7 @@ snowflake = ["snowflake-connector-python"]
|
|||
test = ["tox"]
|
||||
vllm = ["vllm"]
|
||||
weaviate = ["weaviate-client"]
|
||||
xinference-client = ["xinference-client"]
|
||||
zhipuai = ["zhipuai"]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -10940,4 +11000,4 @@ cffi = ["cffi (>=1.11)"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.13"
|
||||
content-hash = "ef927b98c33d704d680e08db0e5c7d9a4e05454c66fcd6a5f656a65eb08e886b"
|
||||
content-hash = "e4794898403da4ad7b51f248a6c07632a949114c1b569406d3aa6a94c62510a5"
|
||||
|
|
|
|||
|
|
@ -206,13 +206,14 @@ cloudscraper = "1.2.71"
|
|||
duckduckgo-search = "~6.3.0"
|
||||
jsonpath-ng = "1.6.1"
|
||||
matplotlib = "~3.8.2"
|
||||
mplfonts = "~0.0.8"
|
||||
newspaper3k = "0.2.8"
|
||||
nltk = "3.9.1"
|
||||
numexpr = "~2.9.0"
|
||||
pydub = "~0.25.1"
|
||||
qrcode = "~7.4.2"
|
||||
twilio = "~9.0.4"
|
||||
vanna = { version = "0.7.3", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
|
||||
vanna = { version = "0.7.5", extras = ["postgres", "mysql", "clickhouse", "duckdb", "oracle"] }
|
||||
wikipedia = "1.4.0"
|
||||
yfinance = "~0.2.40"
|
||||
|
||||
|
|
|
|||
|
|
@ -95,3 +95,7 @@ GPUSTACK_API_KEY=
|
|||
|
||||
# Gitee AI Credentials
|
||||
GITEE_AI_API_KEY=
|
||||
|
||||
# xAI Credentials
|
||||
XAI_API_KEY=
|
||||
XAI_API_BASE=
|
||||
|
|
|
|||
|
|
@ -0,0 +1,204 @@
|
|||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.x.llm.llm import XAILargeLanguageModel
|
||||
|
||||
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
|
||||
from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
|
||||
|
||||
|
||||
def test_predefined_models():
|
||||
model = XAILargeLanguageModel()
|
||||
model_schemas = model.predefined_models()
|
||||
|
||||
assert len(model_schemas) >= 1
|
||||
assert isinstance(model_schemas[0], AIModelEntity)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_validate_credentials_for_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
# model name to gpt-3.5-turbo because of mocking
|
||||
model.validate_credentials(
|
||||
model="gpt-3.5-turbo",
|
||||
credentials={"api_key": "invalid_key", "endpoint_url": os.environ.get("XAI_API_BASE"), "mode": "chat"},
|
||||
)
|
||||
|
||||
model.validate_credentials(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={
|
||||
"temperature": 0.0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"max_tokens": 10,
|
||||
},
|
||||
stop=["How"],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert len(result.message.content) > 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_chat_model_with_tools(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(
|
||||
content="what's the weather today in London?",
|
||||
),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
PromptMessageTool(
|
||||
name="get_stock_price",
|
||||
description="Get the current stock price",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"symbol": {"type": "string", "description": "The stock symbol"}},
|
||||
"required": ["symbol"],
|
||||
},
|
||||
),
|
||||
],
|
||||
stream=False,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, LLMResult)
|
||||
assert isinstance(result.message, AssistantPromptMessage)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_openai_mock", [["chat"]], indirect=True)
|
||||
def test_invoke_stream_chat_model(setup_openai_mock):
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
result = model.invoke(
|
||||
model="grok-beta",
|
||||
credentials={
|
||||
"api_key": os.environ.get("XAI_API_KEY"),
|
||||
"endpoint_url": os.environ.get("XAI_API_BASE"),
|
||||
"mode": "chat",
|
||||
},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
model_parameters={"temperature": 0.0, "max_tokens": 100},
|
||||
stream=True,
|
||||
user="foo",
|
||||
)
|
||||
|
||||
assert isinstance(result, Generator)
|
||||
|
||||
for chunk in result:
|
||||
assert isinstance(chunk, LLMResultChunk)
|
||||
assert isinstance(chunk.delta, LLMResultChunkDelta)
|
||||
assert isinstance(chunk.delta.message, AssistantPromptMessage)
|
||||
assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
|
||||
if chunk.delta.finish_reason is not None:
|
||||
assert chunk.delta.usage is not None
|
||||
assert chunk.delta.usage.completion_tokens > 0
|
||||
|
||||
|
||||
def test_get_num_tokens():
|
||||
model = XAILargeLanguageModel()
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[UserPromptMessage(content="Hello World!")],
|
||||
)
|
||||
|
||||
assert num_tokens == 10
|
||||
|
||||
num_tokens = model.get_num_tokens(
|
||||
model="grok-beta",
|
||||
credentials={"api_key": os.environ.get("XAI_API_KEY"), "endpoint_url": os.environ.get("XAI_API_BASE")},
|
||||
prompt_messages=[
|
||||
SystemPromptMessage(
|
||||
content="You are a helpful AI assistant.",
|
||||
),
|
||||
UserPromptMessage(content="Hello World!"),
|
||||
],
|
||||
tools=[
|
||||
PromptMessageTool(
|
||||
name="get_weather",
|
||||
description="Determine weather in my location",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
|
||||
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert num_tokens == 77
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
import pytest
|
||||
|
||||
from core.app.app_config.entities import VariableEntity, VariableEntityType
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
|
||||
|
||||
def test_validate_inputs_with_zero():
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
var = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=VariableEntityType.NUMBER,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Test with input 0
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value=0,
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
# Test with input "0" (string)
|
||||
result = base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value="0",
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
|
||||
|
||||
def test_validate_input_with_none_for_required_variable():
|
||||
base_app_generator = BaseAppGenerator()
|
||||
|
||||
for var_type in VariableEntityType:
|
||||
var = VariableEntity(
|
||||
variable="test_var",
|
||||
label="test_var",
|
||||
type=var_type,
|
||||
required=True,
|
||||
)
|
||||
|
||||
# Test with input None
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
base_app_generator._validate_inputs(
|
||||
variable_entity=var,
|
||||
value=None,
|
||||
)
|
||||
|
||||
assert str(exc_info.value) == "test_var is required in input form"
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.http_request import (
|
||||
BodyData,
|
||||
HttpRequestNodeAuthorization,
|
||||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Number Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"number": {{#pre_node_id.number#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '{"number": 42}' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value="{{#pre_node_id.object#}}",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"object": {{#pre_node_id.object#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"object": {' in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_extract_selectors_from_template_with_newline():
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(("node_id", "custom_query"), "line1\nline2")
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="test: {{#node_id.custom_query#}}",
|
||||
body=HttpRequestNodeBody(
|
||||
type="none",
|
||||
data=[],
|
||||
),
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
assert executor.params == {"test": "line1\nline2"}
|
||||
|
|
@ -1,5 +1,3 @@
|
|||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
|
@ -16,8 +14,7 @@ from core.workflow.nodes.http_request import (
|
|||
HttpRequestNodeBody,
|
||||
HttpRequestNodeData,
|
||||
)
|
||||
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
|
||||
from core.workflow.nodes.http_request.executor import Executor, _plain_text_to_dict
|
||||
from core.workflow.nodes.http_request.executor import _plain_text_to_dict
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
|
@ -203,167 +200,3 @@ def test_http_request_node_form_with_file(monkeypatch):
|
|||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["body"] == ""
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "number"], 42)
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Number Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"number": {{#pre_node_id.number#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"number": 42}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '{"number": 42}' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value="{{#pre_node_id.object#}}",
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"name": "John Doe", "age": 30, "email": "john@example.com"}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_nested_object_variable():
|
||||
# Prepare the variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
)
|
||||
variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"})
|
||||
|
||||
# Prepare the node data
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Test JSON Body with Nested Object Variable",
|
||||
method="post",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="Content-Type: application/json",
|
||||
params="",
|
||||
body=HttpRequestNodeBody(
|
||||
type="json",
|
||||
data=[
|
||||
BodyData(
|
||||
key="",
|
||||
type="text",
|
||||
value='{"object": {{#pre_node_id.object#}}}',
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Initialize the Executor
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
# Check the executor's data
|
||||
assert executor.method == "post"
|
||||
assert executor.url == "https://api.example.com/data"
|
||||
assert executor.headers == {"Content-Type": "application/json"}
|
||||
assert executor.params == {}
|
||||
assert executor.json == {"object": {"name": "John Doe", "age": 30, "email": "john@example.com"}}
|
||||
assert executor.data is None
|
||||
assert executor.files is None
|
||||
assert executor.content is None
|
||||
|
||||
# Check the raw request (to_log method)
|
||||
raw_request = executor.to_log()
|
||||
assert "POST /data HTTP/1.1" in raw_request
|
||||
assert "Host: api.example.com" in raw_request
|
||||
assert "Content-Type: application/json" in raw_request
|
||||
assert '"object": {' in raw_request
|
||||
assert '"name": "John Doe"' in raw_request
|
||||
assert '"age": 30' in raw_request
|
||||
assert '"email": "john@example.com"' in raw_request
|
||||
|
|
@ -10,6 +10,7 @@ from core.workflow.graph_engine.entities.graph import Graph
|
|||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.event import RunCompletedEvent
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.enums import UserFrom
|
||||
|
|
@ -185,8 +186,6 @@ def test_run():
|
|||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
|
@ -404,18 +403,458 @@ def test_run_parallel():
|
|||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
result = iteration_node._run()
|
||||
|
||||
count = 0
|
||||
for item in result:
|
||||
# print(type(item), item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
|
||||
assert count == 32
|
||||
|
||||
|
||||
def test_iteration_run_in_parallel_mode():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "iteration-start-source-tt-2-target",
|
||||
"source": "iteration-start",
|
||||
"target": "tt-2",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "tt",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "tt-2-source-if-else-target",
|
||||
"source": "tt-2",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-4",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "{{#tt.output#}}",
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 123",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }} 321",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [{"value_selector": ["sys", "query"], "variable": "arg1"}],
|
||||
},
|
||||
"id": "tt-2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "hi",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "no hi", "iteration_id": "iteration-1", "title": "answer 4", "type": "answer"},
|
||||
"id": "answer-4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "dify",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "1",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
|
||||
|
||||
parallel_iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
)
|
||||
sequential_iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "迭代",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
)
|
||||
|
||||
def tt_generator(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"iterator_selector": "dify"},
|
||||
outputs={"output": "dify 123"},
|
||||
)
|
||||
|
||||
with patch.object(TemplateTransformNode, "_run", new=tt_generator):
|
||||
# execute node
|
||||
parallel_result = parallel_iteration_node._run()
|
||||
sequential_result = sequential_iteration_node._run()
|
||||
assert parallel_iteration_node.node_data.parallel_nums == 10
|
||||
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
|
||||
count = 0
|
||||
parallel_arr = []
|
||||
sequential_arr = []
|
||||
for item in parallel_result:
|
||||
count += 1
|
||||
parallel_arr.append(item)
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
assert count == 32
|
||||
|
||||
for item in sequential_result:
|
||||
sequential_arr.append(item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": ["dify 123", "dify 123"]}
|
||||
assert count == 64
|
||||
|
||||
|
||||
def test_iteration_run_error_handle():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-pe-target",
|
||||
"source": "start",
|
||||
"target": "pe",
|
||||
},
|
||||
{
|
||||
"id": "iteration-1-source-answer-3-target",
|
||||
"source": "iteration-1",
|
||||
"target": "answer-3",
|
||||
},
|
||||
{
|
||||
"id": "tt-source-if-else-target",
|
||||
"source": "iteration-start",
|
||||
"target": "if-else",
|
||||
},
|
||||
{
|
||||
"id": "if-else-true-answer-2-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "true",
|
||||
"target": "tt",
|
||||
},
|
||||
{
|
||||
"id": "if-else-false-answer-4-target",
|
||||
"source": "if-else",
|
||||
"sourceHandle": "false",
|
||||
"target": "tt2",
|
||||
},
|
||||
{
|
||||
"id": "pe-source-iteration-1-target",
|
||||
"source": "pe",
|
||||
"target": "iteration-1",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt2", "output"],
|
||||
"output_type": "array[string]",
|
||||
"start_node_id": "if-else",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1.split(arg2) }}",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
|
||||
{"value_selector": ["iteration-1", "index"], "variable": "arg2"},
|
||||
],
|
||||
},
|
||||
"id": "tt",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"template": "{{ arg1 }}",
|
||||
"title": "template transform",
|
||||
"type": "template-transform",
|
||||
"variables": [
|
||||
{"value_selector": ["iteration-1", "item"], "variable": "arg1"},
|
||||
],
|
||||
},
|
||||
"id": "tt2",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "{{#iteration-1.output#}}88888", "title": "answer 3", "type": "answer"},
|
||||
"id": "answer-3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"iteration_id": "iteration-1",
|
||||
"title": "iteration-start",
|
||||
"type": "iteration-start",
|
||||
},
|
||||
"id": "iteration-start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "is",
|
||||
"id": "1721916275284",
|
||||
"value": "1",
|
||||
"variable_selector": ["iteration-1", "item"],
|
||||
}
|
||||
],
|
||||
"iteration_id": "iteration-1",
|
||||
"logical_operator": "and",
|
||||
"title": "if",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"instruction": "test1",
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"parameters": [
|
||||
{"description": "test", "name": "list_output", "required": False, "type": "array[string]"}
|
||||
],
|
||||
"query": ["sys", "query"],
|
||||
"reasoning_mode": "prompt",
|
||||
"title": "pe",
|
||||
"type": "parameter-extractor",
|
||||
},
|
||||
"id": "pe",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "dify",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "1",
|
||||
},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
pool.add(["pe", "list_output"], ["1", "1"])
|
||||
iteration_node = IterationNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
|
||||
config={
|
||||
"data": {
|
||||
"iterator_selector": ["pe", "list_output"],
|
||||
"output_selector": ["tt", "output"],
|
||||
"output_type": "array[string]",
|
||||
"startNodeType": "template-transform",
|
||||
"start_node_id": "iteration-start",
|
||||
"title": "iteration",
|
||||
"type": "iteration",
|
||||
"is_parallel": True,
|
||||
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
|
||||
},
|
||||
"id": "iteration-1",
|
||||
},
|
||||
)
|
||||
# execute continue on error node
|
||||
result = iteration_node._run()
|
||||
result_arr = []
|
||||
count = 0
|
||||
for item in result:
|
||||
result_arr.append(item)
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": [None, None]}
|
||||
|
||||
assert count == 14
|
||||
# execute remove abnormal output
|
||||
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
|
||||
result = iteration_node._run()
|
||||
count = 0
|
||||
for item in result:
|
||||
count += 1
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert item.run_result.outputs == {"output": []}
|
||||
assert count == 14
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from core.file import File
|
||||
from core.file.models import FileTransferMethod, FileType
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.variables import ArrayFileSegment
|
||||
from core.workflow.nodes.list_operator.entities import FilterBy, FilterCondition, Limit, ListOperatorNodeData, OrderBy
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode
|
||||
from core.workflow.nodes.list_operator.exc import InvalidKeyError
|
||||
from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
|
|
@ -109,3 +109,46 @@ def test_filter_files_by_type(list_operator_node):
|
|||
assert expected_file["tenant_id"] == result_file.tenant_id
|
||||
assert expected_file["transfer_method"] == result_file.transfer_method
|
||||
assert expected_file["related_id"] == result_file.related_id
|
||||
|
||||
|
||||
def test_get_file_extract_string_func():
|
||||
# Create a File object
|
||||
file = File(
|
||||
tenant_id="test_tenant",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
filename="test_file.txt",
|
||||
extension=".txt",
|
||||
mime_type="text/plain",
|
||||
remote_url="https://example.com/test_file.txt",
|
||||
related_id="test_related_id",
|
||||
)
|
||||
|
||||
# Test each case
|
||||
assert _get_file_extract_string_func(key="name")(file) == "test_file.txt"
|
||||
assert _get_file_extract_string_func(key="type")(file) == "document"
|
||||
assert _get_file_extract_string_func(key="extension")(file) == ".txt"
|
||||
assert _get_file_extract_string_func(key="mime_type")(file) == "text/plain"
|
||||
assert _get_file_extract_string_func(key="transfer_method")(file) == "local_file"
|
||||
assert _get_file_extract_string_func(key="url")(file) == "https://example.com/test_file.txt"
|
||||
|
||||
# Test with empty values
|
||||
empty_file = File(
|
||||
tenant_id="test_tenant",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
filename=None,
|
||||
extension=None,
|
||||
mime_type=None,
|
||||
remote_url=None,
|
||||
related_id="test_related_id",
|
||||
)
|
||||
|
||||
assert _get_file_extract_string_func(key="name")(empty_file) == ""
|
||||
assert _get_file_extract_string_func(key="extension")(empty_file) == ""
|
||||
assert _get_file_extract_string_func(key="mime_type")(empty_file) == ""
|
||||
assert _get_file_extract_string_func(key="url")(empty_file) == ""
|
||||
|
||||
# Test invalid key
|
||||
with pytest.raises(InvalidKeyError):
|
||||
_get_file_extract_string_func(key="invalid_key")
|
||||
|
|
|
|||
|
|
@ -9,10 +9,10 @@ if ! command -v ruff &> /dev/null || ! command -v dotenv-linter &> /dev/null; th
|
|||
fi
|
||||
|
||||
# run ruff linter
|
||||
ruff check --fix ./api
|
||||
poetry run -C api ruff check --fix ./api
|
||||
|
||||
# run ruff formatter
|
||||
ruff format ./api
|
||||
poetry run -C api ruff format ./api
|
||||
|
||||
# run dotenv-linter linter
|
||||
dotenv-linter ./api/.env.example ./web/.env.example
|
||||
poetry run -C api dotenv-linter ./api/.env.example ./web/.env.example
|
||||
|
|
|
|||
|
|
@ -690,6 +690,7 @@ WORKFLOW_MAX_EXECUTION_STEPS=500
|
|||
WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
|
||||
# HTTP request node in workflow configuration
|
||||
HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
x-shared-env: &shared-api-worker-env
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10}
|
||||
LOG_LEVEL: ${LOG_LEVEL:-INFO}
|
||||
LOG_FILE: ${LOG_FILE:-}
|
||||
LOG_FILE_MAX_SIZE: ${LOG_FILE_MAX_SIZE:-20}
|
||||
|
|
|
|||
|
|
@ -23,8 +23,9 @@ export default function AppSelector() {
|
|||
params: {},
|
||||
})
|
||||
|
||||
if (localStorage?.getItem('console_token'))
|
||||
localStorage.removeItem('console_token')
|
||||
localStorage.removeItem('setup_status')
|
||||
localStorage.removeItem('console_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
|
||||
router.push('/signin')
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,5 +3,6 @@ export const IMG_SIZE_LIMIT = 10 * 1024 * 1024
|
|||
export const FILE_SIZE_LIMIT = 15 * 1024 * 1024
|
||||
export const AUDIO_SIZE_LIMIT = 50 * 1024 * 1024
|
||||
export const VIDEO_SIZE_LIMIT = 100 * 1024 * 1024
|
||||
export const MAX_FILE_UPLOAD_LIMIT = 10
|
||||
|
||||
export const FILE_URL_REGEX = /^(https?|ftp):\/\//
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import {
|
|||
AUDIO_SIZE_LIMIT,
|
||||
FILE_SIZE_LIMIT,
|
||||
IMG_SIZE_LIMIT,
|
||||
MAX_FILE_UPLOAD_LIMIT,
|
||||
VIDEO_SIZE_LIMIT,
|
||||
} from '@/app/components/base/file-uploader/constants'
|
||||
import { useToastContext } from '@/app/components/base/toast'
|
||||
|
|
@ -33,12 +34,14 @@ export const useFileSizeLimit = (fileUploadConfig?: FileUploadConfigResponse) =>
|
|||
const docSizeLimit = Number(fileUploadConfig?.file_size_limit) * 1024 * 1024 || FILE_SIZE_LIMIT
|
||||
const audioSizeLimit = Number(fileUploadConfig?.audio_file_size_limit) * 1024 * 1024 || AUDIO_SIZE_LIMIT
|
||||
const videoSizeLimit = Number(fileUploadConfig?.video_file_size_limit) * 1024 * 1024 || VIDEO_SIZE_LIMIT
|
||||
const maxFileUploadLimit = Number(fileUploadConfig?.workflow_file_upload_limit) || MAX_FILE_UPLOAD_LIMIT
|
||||
|
||||
return {
|
||||
imgSizeLimit,
|
||||
docSizeLimit,
|
||||
audioSizeLimit,
|
||||
videoSizeLimit,
|
||||
maxFileUploadLimit,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ const Select: FC<ISelectProps> = ({
|
|||
</Combobox.Button>
|
||||
</div>
|
||||
|
||||
{filteredItems.length > 0 && (
|
||||
{(filteredItems.length > 0 && open) && (
|
||||
<Combobox.Options className={`absolute z-10 mt-1 px-1 max-h-60 w-full overflow-auto rounded-md bg-white py-1 text-base shadow-lg border-gray-200 border-[0.5px] focus:outline-none sm:text-sm ${overlayClassName}`}>
|
||||
{filteredItems.map((item: Item) => (
|
||||
<Combobox.Option
|
||||
|
|
|
|||
|
|
@ -47,8 +47,9 @@ export default function AppSelector({ isMobile }: IAppSelector) {
|
|||
params: {},
|
||||
})
|
||||
|
||||
if (localStorage?.getItem('console_token'))
|
||||
localStorage.removeItem('console_token')
|
||||
localStorage.removeItem('setup_status')
|
||||
localStorage.removeItem('console_token')
|
||||
localStorage.removeItem('refresh_token')
|
||||
|
||||
router.push('/signin')
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import { SWRConfig } from 'swr'
|
|||
import { useCallback, useEffect, useState } from 'react'
|
||||
import type { ReactNode } from 'react'
|
||||
import { usePathname, useRouter, useSearchParams } from 'next/navigation'
|
||||
import useRefreshToken from '@/hooks/use-refresh-token'
|
||||
import { fetchSetupStatus } from '@/service/common'
|
||||
|
||||
type SwrInitorProps = {
|
||||
|
|
@ -15,12 +14,11 @@ const SwrInitor = ({
|
|||
}: SwrInitorProps) => {
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const pathname = usePathname()
|
||||
const { getNewAccessToken } = useRefreshToken()
|
||||
const consoleToken = searchParams.get('access_token')
|
||||
const refreshToken = searchParams.get('refresh_token')
|
||||
const consoleToken = decodeURIComponent(searchParams.get('access_token') || '')
|
||||
const refreshToken = decodeURIComponent(searchParams.get('refresh_token') || '')
|
||||
const consoleTokenFromLocalStorage = localStorage?.getItem('console_token')
|
||||
const refreshTokenFromLocalStorage = localStorage?.getItem('refresh_token')
|
||||
const pathname = usePathname()
|
||||
const [init, setInit] = useState(false)
|
||||
|
||||
const isSetupFinished = useCallback(async () => {
|
||||
|
|
@ -41,25 +39,6 @@ const SwrInitor = ({
|
|||
}
|
||||
}, [])
|
||||
|
||||
const setRefreshToken = useCallback(async () => {
|
||||
try {
|
||||
if (!(consoleToken || refreshToken || consoleTokenFromLocalStorage || refreshTokenFromLocalStorage))
|
||||
return Promise.reject(new Error('No token found'))
|
||||
|
||||
if (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage)
|
||||
await getNewAccessToken()
|
||||
|
||||
if (consoleToken && refreshToken) {
|
||||
localStorage.setItem('console_token', consoleToken)
|
||||
localStorage.setItem('refresh_token', refreshToken)
|
||||
await getNewAccessToken()
|
||||
}
|
||||
}
|
||||
catch (error) {
|
||||
return Promise.reject(error)
|
||||
}
|
||||
}, [consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage, getNewAccessToken])
|
||||
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
try {
|
||||
|
|
@ -68,9 +47,15 @@ const SwrInitor = ({
|
|||
router.replace('/install')
|
||||
return
|
||||
}
|
||||
await setRefreshToken()
|
||||
if (searchParams.has('access_token') || searchParams.has('refresh_token'))
|
||||
if (!((consoleToken && refreshToken) || (consoleTokenFromLocalStorage && refreshTokenFromLocalStorage))) {
|
||||
router.replace('/signin')
|
||||
return
|
||||
}
|
||||
if (searchParams.has('access_token') || searchParams.has('refresh_token')) {
|
||||
consoleToken && localStorage.setItem('console_token', consoleToken)
|
||||
refreshToken && localStorage.setItem('refresh_token', refreshToken)
|
||||
router.replace(pathname)
|
||||
}
|
||||
|
||||
setInit(true)
|
||||
}
|
||||
|
|
@ -78,7 +63,7 @@ const SwrInitor = ({
|
|||
router.replace('/signin')
|
||||
}
|
||||
})()
|
||||
}, [isSetupFinished, setRefreshToken, router, pathname, searchParams])
|
||||
}, [isSetupFinished, router, pathname, searchParams, consoleToken, refreshToken, consoleTokenFromLocalStorage, refreshTokenFromLocalStorage])
|
||||
|
||||
return init
|
||||
? (
|
||||
|
|
|
|||
|
|
@ -340,7 +340,9 @@ export const NODES_INITIAL_DATA = {
|
|||
...ListFilterDefault.defaultValue,
|
||||
},
|
||||
}
|
||||
|
||||
export const MAX_ITERATION_PARALLEL_NUM = 10
|
||||
export const MIN_ITERATION_PARALLEL_NUM = 1
|
||||
export const DEFAULT_ITER_TIMES = 1
|
||||
export const NODE_WIDTH = 240
|
||||
export const X_OFFSET = 60
|
||||
export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET
|
||||
|
|
|
|||
|
|
@ -644,6 +644,11 @@ export const useNodesInteractions = () => {
|
|||
newNode.data.isInIteration = true
|
||||
newNode.data.iteration_id = prevNode.parentId
|
||||
newNode.zIndex = ITERATION_CHILDREN_Z_INDEX
|
||||
if (newNode.data.type === BlockEnum.Answer || newNode.data.type === BlockEnum.Tool || newNode.data.type === BlockEnum.Assigner) {
|
||||
const parentIterNodeIndex = nodes.findIndex(node => node.id === prevNode.parentId)
|
||||
const iterNodeData: IterationNodeType = nodes[parentIterNodeIndex].data
|
||||
iterNodeData._isShowTips = true
|
||||
}
|
||||
}
|
||||
|
||||
const newEdge: Edge = {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ import {
|
|||
NodeRunningStatus,
|
||||
WorkflowRunningStatus,
|
||||
} from '../types'
|
||||
import { DEFAULT_ITER_TIMES } from '../constants'
|
||||
import { useWorkflowUpdate } from './use-workflow-interactions'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
import type { IOtherOptions } from '@/service/base'
|
||||
|
|
@ -170,11 +171,13 @@ export const useWorkflowRun = () => {
|
|||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
setIterParallelLogMap,
|
||||
} = workflowStore.getState()
|
||||
const {
|
||||
edges,
|
||||
setEdges,
|
||||
} = store.getState()
|
||||
setIterParallelLogMap(new Map())
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
draft.task_id = task_id
|
||||
draft.result = {
|
||||
|
|
@ -244,6 +247,8 @@ export const useWorkflowRun = () => {
|
|||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
iterParallelLogMap,
|
||||
setIterParallelLogMap,
|
||||
} = workflowStore.getState()
|
||||
const {
|
||||
getNodes,
|
||||
|
|
@ -259,10 +264,21 @@ export const useWorkflowRun = () => {
|
|||
const tracing = draft.tracing!
|
||||
const iterations = tracing.find(trace => trace.node_id === node?.parentId)
|
||||
const currIteration = iterations?.details![node.data.iteration_index] || iterations?.details![iterations.details!.length - 1]
|
||||
currIteration?.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as any)
|
||||
if (!data.parallel_run_id) {
|
||||
currIteration?.push({
|
||||
...data,
|
||||
status: NodeRunningStatus.Running,
|
||||
} as any)
|
||||
}
|
||||
else {
|
||||
if (!iterParallelLogMap.has(data.parallel_run_id))
|
||||
iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any])
|
||||
else
|
||||
iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any)
|
||||
setIterParallelLogMap(iterParallelLogMap)
|
||||
if (iterations)
|
||||
iterations.details = Array.from(iterParallelLogMap.values())
|
||||
}
|
||||
}))
|
||||
}
|
||||
else {
|
||||
|
|
@ -309,6 +325,8 @@ export const useWorkflowRun = () => {
|
|||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
iterParallelLogMap,
|
||||
setIterParallelLogMap,
|
||||
} = workflowStore.getState()
|
||||
const {
|
||||
getNodes,
|
||||
|
|
@ -317,21 +335,21 @@ export const useWorkflowRun = () => {
|
|||
const nodes = getNodes()
|
||||
const nodeParentId = nodes.find(node => node.id === data.node_id)!.parentId
|
||||
if (nodeParentId) {
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
const tracing = draft.tracing!
|
||||
const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node
|
||||
if (!data.execution_metadata.parallel_mode_run_id) {
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
const tracing = draft.tracing!
|
||||
const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node
|
||||
|
||||
if (iterations && iterations.details) {
|
||||
const iterationIndex = data.execution_metadata?.iteration_index || 0
|
||||
if (!iterations.details[iterationIndex])
|
||||
iterations.details[iterationIndex] = []
|
||||
if (iterations && iterations.details) {
|
||||
const iterationIndex = data.execution_metadata?.iteration_index || 0
|
||||
if (!iterations.details[iterationIndex])
|
||||
iterations.details[iterationIndex] = []
|
||||
|
||||
const currIteration = iterations.details[iterationIndex]
|
||||
const nodeIndex = currIteration.findIndex(node =>
|
||||
node.node_id === data.node_id && (
|
||||
node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
|
||||
)
|
||||
if (data.status === NodeRunningStatus.Succeeded) {
|
||||
const currIteration = iterations.details[iterationIndex]
|
||||
const nodeIndex = currIteration.findIndex(node =>
|
||||
node.node_id === data.node_id && (
|
||||
node.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || node.parallel_id === data.execution_metadata?.parallel_id),
|
||||
)
|
||||
if (nodeIndex !== -1) {
|
||||
currIteration[nodeIndex] = {
|
||||
...currIteration[nodeIndex],
|
||||
|
|
@ -344,8 +362,40 @@ export const useWorkflowRun = () => {
|
|||
} as any)
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}))
|
||||
}
|
||||
else {
|
||||
// open parallel mode
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
const tracing = draft.tracing!
|
||||
const iterations = tracing.find(trace => trace.node_id === nodeParentId) // the iteration node
|
||||
|
||||
if (iterations && iterations.details) {
|
||||
const iterRunID = data.execution_metadata?.parallel_mode_run_id
|
||||
|
||||
const currIteration = iterParallelLogMap.get(iterRunID)
|
||||
const nodeIndex = currIteration?.findIndex(node =>
|
||||
node.node_id === data.node_id && (
|
||||
node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id),
|
||||
)
|
||||
if (currIteration) {
|
||||
if (nodeIndex !== undefined && nodeIndex !== -1) {
|
||||
currIteration[nodeIndex] = {
|
||||
...currIteration[nodeIndex],
|
||||
...data,
|
||||
} as any
|
||||
}
|
||||
else {
|
||||
currIteration.push({
|
||||
...data,
|
||||
} as any)
|
||||
}
|
||||
}
|
||||
setIterParallelLogMap(iterParallelLogMap)
|
||||
iterations.details = Array.from(iterParallelLogMap.values())
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
else {
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
|
|
@ -379,6 +429,7 @@ export const useWorkflowRun = () => {
|
|||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
setIterTimes,
|
||||
} = workflowStore.getState()
|
||||
const {
|
||||
getNodes,
|
||||
|
|
@ -388,6 +439,7 @@ export const useWorkflowRun = () => {
|
|||
transform,
|
||||
} = store.getState()
|
||||
const nodes = getNodes()
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
draft.tracing!.push({
|
||||
...data,
|
||||
|
|
@ -431,6 +483,8 @@ export const useWorkflowRun = () => {
|
|||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
iterTimes,
|
||||
setIterTimes,
|
||||
} = workflowStore.getState()
|
||||
|
||||
const { data } = params
|
||||
|
|
@ -445,13 +499,14 @@ export const useWorkflowRun = () => {
|
|||
if (iteration.details!.length >= iteration.metadata.iterator_length!)
|
||||
return
|
||||
}
|
||||
iteration?.details!.push([])
|
||||
if (!data.parallel_mode_run_id)
|
||||
iteration?.details!.push([])
|
||||
}))
|
||||
const nodes = getNodes()
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
const currentNode = draft.find(node => node.id === data.node_id)!
|
||||
|
||||
currentNode.data._iterationIndex = data.index > 0 ? data.index : 1
|
||||
currentNode.data._iterationIndex = iterTimes
|
||||
setIterTimes(iterTimes + 1)
|
||||
})
|
||||
setNodes(newNodes)
|
||||
|
||||
|
|
@ -464,6 +519,7 @@ export const useWorkflowRun = () => {
|
|||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
setIterTimes,
|
||||
} = workflowStore.getState()
|
||||
const {
|
||||
getNodes,
|
||||
|
|
@ -480,7 +536,7 @@ export const useWorkflowRun = () => {
|
|||
})
|
||||
}
|
||||
}))
|
||||
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
const currentNode = draft.find(node => node.id === data.node_id)!
|
||||
|
||||
|
|
|
|||
|
|
@ -12,15 +12,15 @@ import Tooltip from '@/app/components/base/tooltip'
|
|||
type Props = {
|
||||
className?: string
|
||||
title: JSX.Element | string | DefaultTFuncReturn
|
||||
tooltip?: React.ReactNode
|
||||
isSubTitle?: boolean
|
||||
tooltip?: string
|
||||
supportFold?: boolean
|
||||
children?: JSX.Element | string | null
|
||||
operations?: JSX.Element
|
||||
inline?: boolean
|
||||
}
|
||||
|
||||
const Filed: FC<Props> = ({
|
||||
const Field: FC<Props> = ({
|
||||
className,
|
||||
title,
|
||||
isSubTitle,
|
||||
|
|
@ -60,4 +60,4 @@ const Filed: FC<Props> = ({
|
|||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(Filed)
|
||||
export default React.memo(Field)
|
||||
|
|
|
|||
|
|
@ -39,7 +39,13 @@ const FileUploadSetting: FC<Props> = ({
|
|||
allowed_file_extensions,
|
||||
} = payload
|
||||
const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
|
||||
const { imgSizeLimit, docSizeLimit, audioSizeLimit, videoSizeLimit } = useFileSizeLimit(fileUploadConfigResponse)
|
||||
const {
|
||||
imgSizeLimit,
|
||||
docSizeLimit,
|
||||
audioSizeLimit,
|
||||
videoSizeLimit,
|
||||
maxFileUploadLimit,
|
||||
} = useFileSizeLimit(fileUploadConfigResponse)
|
||||
|
||||
const handleSupportFileTypeChange = useCallback((type: SupportUploadFileTypes) => {
|
||||
const newPayload = produce(payload, (draft) => {
|
||||
|
|
@ -156,7 +162,7 @@ const FileUploadSetting: FC<Props> = ({
|
|||
<InputNumberWithSlider
|
||||
value={max_length}
|
||||
min={1}
|
||||
max={10}
|
||||
max={maxFileUploadLimit}
|
||||
onChange={handleMaxUploadNumLimitChange}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ import {
|
|||
useToolIcon,
|
||||
} from '../../hooks'
|
||||
import { useNodeIterationInteractions } from '../iteration/use-interactions'
|
||||
import type { IterationNodeType } from '../iteration/types'
|
||||
import {
|
||||
NodeSourceHandle,
|
||||
NodeTargetHandle,
|
||||
|
|
@ -34,6 +35,7 @@ import NodeControl from './components/node-control'
|
|||
import AddVariablePopupWithPosition from './components/add-variable-popup-with-position'
|
||||
import cn from '@/utils/classnames'
|
||||
import BlockIcon from '@/app/components/workflow/block-icon'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
||||
type BaseNodeProps = {
|
||||
children: ReactElement
|
||||
|
|
@ -166,9 +168,27 @@ const BaseNode: FC<BaseNodeProps> = ({
|
|||
/>
|
||||
<div
|
||||
title={data.title}
|
||||
className='grow mr-1 system-sm-semibold-uppercase text-text-primary truncate'
|
||||
className='grow mr-1 system-sm-semibold-uppercase text-text-primary truncate flex items-center'
|
||||
>
|
||||
{data.title}
|
||||
<div>
|
||||
{data.title}
|
||||
</div>
|
||||
{
|
||||
data.type === BlockEnum.Iteration && (data as IterationNodeType).is_parallel && (
|
||||
<Tooltip popupContent={
|
||||
<div className='w-[180px]'>
|
||||
<div className='font-extrabold'>
|
||||
{t('workflow.nodes.iteration.parallelModeEnableTitle')}
|
||||
</div>
|
||||
{t('workflow.nodes.iteration.parallelModeEnableDesc')}
|
||||
</div>}
|
||||
>
|
||||
<div className='flex justify-center items-center px-[5px] py-[3px] ml-1 border-[1px] border-text-warning rounded-[5px] text-text-warning system-2xs-medium-uppercase '>
|
||||
{t('workflow.nodes.iteration.parallelModeUpper')}
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
{
|
||||
data._iterationLength && data._iterationIndex && data._runningStatus === NodeRunningStatus.Running && (
|
||||
|
|
|
|||
|
|
@ -78,24 +78,24 @@ const useConfig = (id: string, payload: IfElseNodeType) => {
|
|||
})
|
||||
|
||||
const handleAddCase = useCallback(() => {
|
||||
const newInputs = produce(inputs, () => {
|
||||
if (inputs.cases) {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
if (draft.cases) {
|
||||
const case_id = uuid4()
|
||||
inputs.cases.push({
|
||||
draft.cases.push({
|
||||
case_id,
|
||||
logical_operator: LogicalOperator.and,
|
||||
conditions: [],
|
||||
})
|
||||
if (inputs._targetBranches) {
|
||||
const elseCaseIndex = inputs._targetBranches.findIndex(branch => branch.id === 'false')
|
||||
if (draft._targetBranches) {
|
||||
const elseCaseIndex = draft._targetBranches.findIndex(branch => branch.id === 'false')
|
||||
if (elseCaseIndex > -1) {
|
||||
inputs._targetBranches = branchNameCorrect([
|
||||
...inputs._targetBranches.slice(0, elseCaseIndex),
|
||||
draft._targetBranches = branchNameCorrect([
|
||||
...draft._targetBranches.slice(0, elseCaseIndex),
|
||||
{
|
||||
id: case_id,
|
||||
name: '',
|
||||
},
|
||||
...inputs._targetBranches.slice(elseCaseIndex),
|
||||
...draft._targetBranches.slice(elseCaseIndex),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
import { BlockEnum } from '../../types'
|
||||
import { BlockEnum, ErrorHandleMode } from '../../types'
|
||||
import type { NodeDefault } from '../../types'
|
||||
import type { IterationNodeType } from './types'
|
||||
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
|
||||
import {
|
||||
ALL_CHAT_AVAILABLE_BLOCKS,
|
||||
ALL_COMPLETION_AVAILABLE_BLOCKS,
|
||||
} from '@/app/components/workflow/constants'
|
||||
const i18nPrefix = 'workflow'
|
||||
|
||||
const nodeDefault: NodeDefault<IterationNodeType> = {
|
||||
|
|
@ -10,25 +13,45 @@ const nodeDefault: NodeDefault<IterationNodeType> = {
|
|||
iterator_selector: [],
|
||||
output_selector: [],
|
||||
_children: [],
|
||||
_isShowTips: false,
|
||||
is_parallel: false,
|
||||
parallel_nums: 10,
|
||||
error_handle_mode: ErrorHandleMode.Terminated,
|
||||
},
|
||||
getAvailablePrevNodes(isChatMode: boolean) {
|
||||
const nodes = isChatMode
|
||||
? ALL_CHAT_AVAILABLE_BLOCKS
|
||||
: ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End)
|
||||
: ALL_COMPLETION_AVAILABLE_BLOCKS.filter(
|
||||
type => type !== BlockEnum.End,
|
||||
)
|
||||
return nodes
|
||||
},
|
||||
getAvailableNextNodes(isChatMode: boolean) {
|
||||
const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
|
||||
const nodes = isChatMode
|
||||
? ALL_CHAT_AVAILABLE_BLOCKS
|
||||
: ALL_COMPLETION_AVAILABLE_BLOCKS
|
||||
return nodes
|
||||
},
|
||||
checkValid(payload: IterationNodeType, t: any) {
|
||||
let errorMessages = ''
|
||||
|
||||
if (!errorMessages && (!payload.iterator_selector || payload.iterator_selector.length === 0))
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.input`) })
|
||||
if (
|
||||
!errorMessages
|
||||
&& (!payload.iterator_selector || payload.iterator_selector.length === 0)
|
||||
) {
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, {
|
||||
field: t(`${i18nPrefix}.nodes.iteration.input`),
|
||||
})
|
||||
}
|
||||
|
||||
if (!errorMessages && (!payload.output_selector || payload.output_selector.length === 0))
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.iteration.output`) })
|
||||
if (
|
||||
!errorMessages
|
||||
&& (!payload.output_selector || payload.output_selector.length === 0)
|
||||
) {
|
||||
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, {
|
||||
field: t(`${i18nPrefix}.nodes.iteration.output`),
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: !errorMessages,
|
||||
|
|
|
|||
|
|
@ -8,12 +8,16 @@ import {
|
|||
useNodesInitialized,
|
||||
useViewport,
|
||||
} from 'reactflow'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { IterationStartNodeDumb } from '../iteration-start'
|
||||
import { useNodeIterationInteractions } from './use-interactions'
|
||||
import type { IterationNodeType } from './types'
|
||||
import AddBlock from './add-block'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { NodeProps } from '@/app/components/workflow/types'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.iteration'
|
||||
|
||||
const Node: FC<NodeProps<IterationNodeType>> = ({
|
||||
id,
|
||||
|
|
@ -22,11 +26,20 @@ const Node: FC<NodeProps<IterationNodeType>> = ({
|
|||
const { zoom } = useViewport()
|
||||
const nodesInitialized = useNodesInitialized()
|
||||
const { handleNodeIterationRerender } = useNodeIterationInteractions()
|
||||
const { t } = useTranslation()
|
||||
|
||||
useEffect(() => {
|
||||
if (nodesInitialized)
|
||||
handleNodeIterationRerender(id)
|
||||
}, [nodesInitialized, id, handleNodeIterationRerender])
|
||||
if (data.is_parallel && data._isShowTips) {
|
||||
Toast.notify({
|
||||
type: 'warning',
|
||||
message: t(`${i18nPrefix}.answerNodeWarningDesc`),
|
||||
duration: 5000,
|
||||
})
|
||||
data._isShowTips = false
|
||||
}
|
||||
}, [nodesInitialized, id, handleNodeIterationRerender, data, t])
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
|
|
|
|||
|
|
@ -8,11 +8,17 @@ import VarReferencePicker from '../_base/components/variable/var-reference-picke
|
|||
import Split from '../_base/components/split'
|
||||
import ResultPanel from '../../run/result-panel'
|
||||
import IterationResultPanel from '../../run/iteration-result-panel'
|
||||
import { MAX_ITERATION_PARALLEL_NUM, MIN_ITERATION_PARALLEL_NUM } from '../../constants'
|
||||
import type { IterationNodeType } from './types'
|
||||
import useConfig from './use-config'
|
||||
import { InputVarType, type NodePanelProps } from '@/app/components/workflow/types'
|
||||
import { ErrorHandleMode, InputVarType, type NodePanelProps } from '@/app/components/workflow/types'
|
||||
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
||||
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import Select from '@/app/components/base/select'
|
||||
import Slider from '@/app/components/base/slider'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
|
||||
const i18nPrefix = 'workflow.nodes.iteration'
|
||||
|
||||
|
|
@ -21,7 +27,20 @@ const Panel: FC<NodePanelProps<IterationNodeType>> = ({
|
|||
data,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const responseMethod = [
|
||||
{
|
||||
value: ErrorHandleMode.Terminated,
|
||||
name: t(`${i18nPrefix}.ErrorMethod.operationTerminated`),
|
||||
},
|
||||
{
|
||||
value: ErrorHandleMode.ContinueOnError,
|
||||
name: t(`${i18nPrefix}.ErrorMethod.continueOnError`),
|
||||
},
|
||||
{
|
||||
value: ErrorHandleMode.RemoveAbnormalOutput,
|
||||
name: t(`${i18nPrefix}.ErrorMethod.removeAbnormalOutput`),
|
||||
},
|
||||
]
|
||||
const {
|
||||
readOnly,
|
||||
inputs,
|
||||
|
|
@ -47,6 +66,9 @@ const Panel: FC<NodePanelProps<IterationNodeType>> = ({
|
|||
setIterator,
|
||||
iteratorInputKey,
|
||||
iterationRunResult,
|
||||
changeParallel,
|
||||
changeErrorResponseMode,
|
||||
changeParallelNums,
|
||||
} = useConfig(id, data)
|
||||
|
||||
return (
|
||||
|
|
@ -87,6 +109,39 @@ const Panel: FC<NodePanelProps<IterationNodeType>> = ({
|
|||
/>
|
||||
</Field>
|
||||
</div>
|
||||
<div className='px-4 pb-2'>
|
||||
<Field title={t(`${i18nPrefix}.parallelMode`)} tooltip={<div className='w-[230px]'>{t(`${i18nPrefix}.parallelPanelDesc`)}</div>} inline>
|
||||
<Switch defaultValue={inputs.is_parallel} onChange={changeParallel} />
|
||||
</Field>
|
||||
</div>
|
||||
{
|
||||
inputs.is_parallel && (<div className='px-4 pb-2'>
|
||||
<Field title={t(`${i18nPrefix}.MaxParallelismTitle`)} isSubTitle tooltip={<div className='w-[230px]'>{t(`${i18nPrefix}.MaxParallelismDesc`)}</div>}>
|
||||
<div className='flex row'>
|
||||
<Input type='number' wrapperClassName='w-18 mr-4 ' max={MAX_ITERATION_PARALLEL_NUM} min={MIN_ITERATION_PARALLEL_NUM} value={inputs.parallel_nums} onChange={(e) => { changeParallelNums(Number(e.target.value)) }} />
|
||||
<Slider
|
||||
value={inputs.parallel_nums}
|
||||
onChange={changeParallelNums}
|
||||
max={MAX_ITERATION_PARALLEL_NUM}
|
||||
min={MIN_ITERATION_PARALLEL_NUM}
|
||||
className=' flex-shrink-0 flex-1 mt-4'
|
||||
/>
|
||||
</div>
|
||||
|
||||
</Field>
|
||||
</div>)
|
||||
}
|
||||
<div className='px-4 py-2'>
|
||||
<Divider className='h-[1px]'/>
|
||||
</div>
|
||||
|
||||
<div className='px-4 py-2'>
|
||||
<Field title={t(`${i18nPrefix}.errorResponseMethod`)} >
|
||||
<Select items={responseMethod} defaultValue={inputs.error_handle_mode} onSelect={changeErrorResponseMode} allowSearch={false}>
|
||||
</Select>
|
||||
</Field>
|
||||
</div>
|
||||
|
||||
{isShowSingleRun && (
|
||||
<BeforeRunForm
|
||||
nodeName={inputs.title}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import type {
|
||||
BlockEnum,
|
||||
CommonNodeType,
|
||||
ErrorHandleMode,
|
||||
ValueSelector,
|
||||
VarType,
|
||||
} from '@/app/components/workflow/types'
|
||||
|
|
@ -12,4 +13,8 @@ export type IterationNodeType = CommonNodeType & {
|
|||
iterator_selector: ValueSelector
|
||||
output_selector: ValueSelector
|
||||
output_type: VarType // output type.
|
||||
is_parallel: boolean // open the parallel mode or not
|
||||
parallel_nums: number // the numbers of parallel
|
||||
error_handle_mode: ErrorHandleMode // how to handle error in the iteration
|
||||
_isShowTips: boolean // when answer node in parallel mode iteration show tips
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,12 +8,13 @@ import {
|
|||
useWorkflow,
|
||||
} from '../../hooks'
|
||||
import { VarType } from '../../types'
|
||||
import type { ValueSelector, Var } from '../../types'
|
||||
import type { ErrorHandleMode, ValueSelector, Var } from '../../types'
|
||||
import useNodeCrud from '../_base/hooks/use-node-crud'
|
||||
import { getNodeInfoById, getNodeUsedVarPassToServerKey, getNodeUsedVars, isSystemVar, toNodeOutputVars } from '../_base/components/variable/utils'
|
||||
import useOneStepRun from '../_base/hooks/use-one-step-run'
|
||||
import type { IterationNodeType } from './types'
|
||||
import type { VarType as VarKindType } from '@/app/components/workflow/nodes/tool/types'
|
||||
import type { Item } from '@/app/components/base/select'
|
||||
|
||||
const DELIMITER = '@@@@@'
|
||||
const useConfig = (id: string, payload: IterationNodeType) => {
|
||||
|
|
@ -184,6 +185,25 @@ const useConfig = (id: string, payload: IterationNodeType) => {
|
|||
})
|
||||
}, [iteratorInputKey, runInputData, setRunInputData])
|
||||
|
||||
const changeParallel = useCallback((value: boolean) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.is_parallel = value
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
|
||||
const changeErrorResponseMode = useCallback((item: Item) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.error_handle_mode = item.value as ErrorHandleMode
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
const changeParallelNums = useCallback((num: number) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.parallel_nums = num
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs])
|
||||
return {
|
||||
readOnly,
|
||||
inputs,
|
||||
|
|
@ -210,6 +230,9 @@ const useConfig = (id: string, payload: IterationNodeType) => {
|
|||
setIterator,
|
||||
iteratorInputKey,
|
||||
iterationRunResult,
|
||||
changeParallel,
|
||||
changeErrorResponseMode,
|
||||
changeParallelNums,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ import { produce, setAutoFreeze } from 'immer'
|
|||
import { uniqBy } from 'lodash-es'
|
||||
import { useWorkflowRun } from '../../hooks'
|
||||
import { NodeRunningStatus, WorkflowRunningStatus } from '../../types'
|
||||
import { useWorkflowStore } from '../../store'
|
||||
import { DEFAULT_ITER_TIMES } from '../../constants'
|
||||
import type {
|
||||
ChatItem,
|
||||
Inputs,
|
||||
|
|
@ -43,6 +45,7 @@ export const useChat = (
|
|||
const { notify } = useToastContext()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const hasStopResponded = useRef(false)
|
||||
const workflowStore = useWorkflowStore()
|
||||
const conversationId = useRef('')
|
||||
const taskIdRef = useRef('')
|
||||
const [chatList, setChatList] = useState<ChatItem[]>(prevChatList || [])
|
||||
|
|
@ -52,6 +55,9 @@ export const useChat = (
|
|||
const [suggestedQuestions, setSuggestQuestions] = useState<string[]>([])
|
||||
const suggestedQuestionsAbortControllerRef = useRef<AbortController | null>(null)
|
||||
|
||||
const {
|
||||
setIterTimes,
|
||||
} = workflowStore.getState()
|
||||
useEffect(() => {
|
||||
setAutoFreeze(false)
|
||||
return () => {
|
||||
|
|
@ -102,15 +108,16 @@ export const useChat = (
|
|||
handleResponding(false)
|
||||
if (stopChat && taskIdRef.current)
|
||||
stopChat(taskIdRef.current)
|
||||
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
if (suggestedQuestionsAbortControllerRef.current)
|
||||
suggestedQuestionsAbortControllerRef.current.abort()
|
||||
}, [handleResponding, stopChat])
|
||||
}, [handleResponding, setIterTimes, stopChat])
|
||||
|
||||
const handleRestart = useCallback(() => {
|
||||
conversationId.current = ''
|
||||
taskIdRef.current = ''
|
||||
handleStop()
|
||||
setIterTimes(DEFAULT_ITER_TIMES)
|
||||
const newChatList = config?.opening_statement
|
||||
? [{
|
||||
id: `${Date.now()}`,
|
||||
|
|
@ -126,6 +133,7 @@ export const useChat = (
|
|||
config,
|
||||
handleStop,
|
||||
handleUpdateChatList,
|
||||
setIterTimes,
|
||||
])
|
||||
|
||||
const updateCurrentQA = useCallback(({
|
||||
|
|
|
|||
|
|
@ -60,36 +60,67 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
|
|||
}, [notify, getResultCallback])
|
||||
|
||||
const formatNodeList = useCallback((list: NodeTracing[]) => {
|
||||
const allItems = list.reverse()
|
||||
const allItems = [...list].reverse()
|
||||
const result: NodeTracing[] = []
|
||||
allItems.forEach((item) => {
|
||||
const { node_type, execution_metadata } = item
|
||||
if (node_type !== BlockEnum.Iteration) {
|
||||
const isInIteration = !!execution_metadata?.iteration_id
|
||||
const groupMap = new Map<string, NodeTracing[]>()
|
||||
|
||||
if (isInIteration) {
|
||||
const iterationNode = result.find(node => node.node_id === execution_metadata?.iteration_id)
|
||||
const iterationDetails = iterationNode?.details
|
||||
const currentIterationIndex = execution_metadata?.iteration_index ?? 0
|
||||
|
||||
if (Array.isArray(iterationDetails)) {
|
||||
if (iterationDetails.length === 0 || !iterationDetails[currentIterationIndex])
|
||||
iterationDetails[currentIterationIndex] = [item]
|
||||
else
|
||||
iterationDetails[currentIterationIndex].push(item)
|
||||
}
|
||||
return
|
||||
}
|
||||
// not in iteration
|
||||
result.push(item)
|
||||
|
||||
return
|
||||
}
|
||||
const processIterationNode = (item: NodeTracing) => {
|
||||
result.push({
|
||||
...item,
|
||||
details: [],
|
||||
})
|
||||
}
|
||||
const updateParallelModeGroup = (runId: string, item: NodeTracing, iterationNode: NodeTracing) => {
|
||||
if (!groupMap.has(runId))
|
||||
groupMap.set(runId, [item])
|
||||
else
|
||||
groupMap.get(runId)!.push(item)
|
||||
if (item.status === 'failed') {
|
||||
iterationNode.status = 'failed'
|
||||
iterationNode.error = item.error
|
||||
}
|
||||
|
||||
iterationNode.details = Array.from(groupMap.values())
|
||||
}
|
||||
const updateSequentialModeGroup = (index: number, item: NodeTracing, iterationNode: NodeTracing) => {
|
||||
const { details } = iterationNode
|
||||
if (details) {
|
||||
if (!details[index])
|
||||
details[index] = [item]
|
||||
else
|
||||
details[index].push(item)
|
||||
}
|
||||
|
||||
if (item.status === 'failed') {
|
||||
iterationNode.status = 'failed'
|
||||
iterationNode.error = item.error
|
||||
}
|
||||
}
|
||||
const processNonIterationNode = (item: NodeTracing) => {
|
||||
const { execution_metadata } = item
|
||||
if (!execution_metadata?.iteration_id) {
|
||||
result.push(item)
|
||||
return
|
||||
}
|
||||
|
||||
const iterationNode = result.find(node => node.node_id === execution_metadata.iteration_id)
|
||||
if (!iterationNode || !Array.isArray(iterationNode.details))
|
||||
return
|
||||
|
||||
const { parallel_mode_run_id, iteration_index = 0 } = execution_metadata
|
||||
|
||||
if (parallel_mode_run_id)
|
||||
updateParallelModeGroup(parallel_mode_run_id, item, iterationNode)
|
||||
else
|
||||
updateSequentialModeGroup(iteration_index, item, iterationNode)
|
||||
}
|
||||
|
||||
allItems.forEach((item) => {
|
||||
item.node_type === BlockEnum.Iteration
|
||||
? processIterationNode(item)
|
||||
: processNonIterationNode(item)
|
||||
})
|
||||
|
||||
return result
|
||||
}, [])
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import { useTranslation } from 'react-i18next'
|
|||
import {
|
||||
RiArrowRightSLine,
|
||||
RiCloseLine,
|
||||
RiErrorWarningLine,
|
||||
} from '@remixicon/react'
|
||||
import { ArrowNarrowLeft } from '../../base/icons/src/vender/line/arrows'
|
||||
import TracingPanel from './tracing-panel'
|
||||
|
|
@ -27,7 +28,7 @@ const IterationResultPanel: FC<Props> = ({
|
|||
noWrap,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [expandedIterations, setExpandedIterations] = useState<Record<number, boolean>>([])
|
||||
const [expandedIterations, setExpandedIterations] = useState<Record<number, boolean>>({})
|
||||
|
||||
const toggleIteration = useCallback((index: number) => {
|
||||
setExpandedIterations(prev => ({
|
||||
|
|
@ -71,10 +72,19 @@ const IterationResultPanel: FC<Props> = ({
|
|||
<span className='system-sm-semibold-uppercase text-text-primary flex-grow'>
|
||||
{t(`${i18nPrefix}.iteration`)} {index + 1}
|
||||
</span>
|
||||
<RiArrowRightSLine className={cn(
|
||||
'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0',
|
||||
expandedIterations[index] && 'transform rotate-90',
|
||||
)} />
|
||||
{
|
||||
iteration.some(item => item.status === 'failed')
|
||||
? (
|
||||
<RiErrorWarningLine className='w-4 h-4 text-text-destructive' />
|
||||
)
|
||||
: (< RiArrowRightSLine className={
|
||||
cn(
|
||||
'w-4 h-4 text-text-tertiary transition-transform duration-200 flex-shrink-0',
|
||||
expandedIterations[index] && 'transform rotate-90',
|
||||
)} />
|
||||
)
|
||||
}
|
||||
|
||||
</div>
|
||||
</div>
|
||||
{expandedIterations[index] && <div
|
||||
|
|
|
|||
|
|
@ -72,7 +72,16 @@ const NodePanel: FC<Props> = ({
|
|||
|
||||
return iteration_length
|
||||
}
|
||||
const getErrorCount = (details: NodeTracing[][] | undefined) => {
|
||||
if (!details || details.length === 0)
|
||||
return 0
|
||||
|
||||
return details.reduce((acc, iteration) => {
|
||||
if (iteration.some(item => item.status === 'failed'))
|
||||
acc++
|
||||
return acc
|
||||
}, 0)
|
||||
}
|
||||
useEffect(() => {
|
||||
setCollapseState(!nodeInfo.expand)
|
||||
}, [nodeInfo.expand, setCollapseState])
|
||||
|
|
@ -136,7 +145,12 @@ const NodePanel: FC<Props> = ({
|
|||
onClick={handleOnShowIterationDetail}
|
||||
>
|
||||
<Iteration className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
|
||||
<div className='flex-1 text-left system-sm-medium text-components-button-tertiary-text'>{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}</div>
|
||||
<div className='flex-1 text-left system-sm-medium text-components-button-tertiary-text'>{t('workflow.nodes.iteration.iteration', { count: getCount(nodeInfo.details?.length, nodeInfo.metadata?.iterator_length) })}{getErrorCount(nodeInfo.details) > 0 && (
|
||||
<>
|
||||
{t('workflow.nodes.iteration.comma')}
|
||||
{t('workflow.nodes.iteration.error', { count: getErrorCount(nodeInfo.details) })}
|
||||
</>
|
||||
)}</div>
|
||||
{justShowIterationNavArrow
|
||||
? (
|
||||
<RiArrowRightSLine className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import type {
|
|||
WorkflowRunningData,
|
||||
} from './types'
|
||||
import { WorkflowContext } from './context'
|
||||
import type { NodeTracing } from '@/types/workflow'
|
||||
|
||||
// #TODO chatVar#
|
||||
// const MOCK_DATA = [
|
||||
|
|
@ -166,6 +167,10 @@ type Shape = {
|
|||
setShowImportDSLModal: (showImportDSLModal: boolean) => void
|
||||
showTips: string
|
||||
setShowTips: (showTips: string) => void
|
||||
iterTimes: number
|
||||
setIterTimes: (iterTimes: number) => void
|
||||
iterParallelLogMap: Map<string, NodeTracing[]>
|
||||
setIterParallelLogMap: (iterParallelLogMap: Map<string, NodeTracing[]>) => void
|
||||
}
|
||||
|
||||
export const createWorkflowStore = () => {
|
||||
|
|
@ -281,6 +286,11 @@ export const createWorkflowStore = () => {
|
|||
setShowImportDSLModal: showImportDSLModal => set(() => ({ showImportDSLModal })),
|
||||
showTips: '',
|
||||
setShowTips: showTips => set(() => ({ showTips })),
|
||||
iterTimes: 1,
|
||||
setIterTimes: iterTimes => set(() => ({ iterTimes })),
|
||||
iterParallelLogMap: new Map<string, NodeTracing[]>(),
|
||||
setIterParallelLogMap: iterParallelLogMap => set(() => ({ iterParallelLogMap })),
|
||||
|
||||
}))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,11 @@ export enum ControlMode {
|
|||
Pointer = 'pointer',
|
||||
Hand = 'hand',
|
||||
}
|
||||
|
||||
export enum ErrorHandleMode {
|
||||
Terminated = 'terminated',
|
||||
ContinueOnError = 'continue-on-error',
|
||||
RemoveAbnormalOutput = 'remove-abnormal-output',
|
||||
}
|
||||
export type Branch = {
|
||||
id: string
|
||||
name: string
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import type {
|
|||
ToolWithProvider,
|
||||
ValueSelector,
|
||||
} from './types'
|
||||
import { BlockEnum } from './types'
|
||||
import { BlockEnum, ErrorHandleMode } from './types'
|
||||
import {
|
||||
CUSTOM_NODE,
|
||||
ITERATION_CHILDREN_Z_INDEX,
|
||||
|
|
@ -267,8 +267,13 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => {
|
|||
})
|
||||
}
|
||||
|
||||
if (node.data.type === BlockEnum.Iteration)
|
||||
node.data._children = iterationNodeMap[node.id] || []
|
||||
if (node.data.type === BlockEnum.Iteration) {
|
||||
const iterationNodeData = node.data as IterationNodeType
|
||||
iterationNodeData._children = iterationNodeMap[node.id] || []
|
||||
iterationNodeData.is_parallel = iterationNodeData.is_parallel || false
|
||||
iterationNodeData.parallel_nums = iterationNodeData.parallel_nums || 10
|
||||
iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated
|
||||
}
|
||||
|
||||
return node
|
||||
})
|
||||
|
|
|
|||
|
|
@ -12,11 +12,9 @@ import cn from '@/utils/classnames'
|
|||
import { getSystemFeatures, invitationCheck } from '@/service/common'
|
||||
import { defaultSystemFeatures } from '@/types/feature'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import useRefreshToken from '@/hooks/use-refresh-token'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
|
||||
const NormalForm = () => {
|
||||
const { getNewAccessToken } = useRefreshToken()
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
|
|
@ -38,7 +36,6 @@ const NormalForm = () => {
|
|||
if (consoleToken && refreshToken) {
|
||||
localStorage.setItem('console_token', consoleToken)
|
||||
localStorage.setItem('refresh_token', refreshToken)
|
||||
getNewAccessToken()
|
||||
router.replace('/apps')
|
||||
return
|
||||
}
|
||||
|
|
@ -71,7 +68,7 @@ const NormalForm = () => {
|
|||
setSystemFeatures(defaultSystemFeatures)
|
||||
}
|
||||
finally { setIsLoading(false) }
|
||||
}, [consoleToken, refreshToken, message, router, invite_token, isInviteLink, getNewAccessToken])
|
||||
}, [consoleToken, refreshToken, message, router, invite_token, isInviteLink])
|
||||
useEffect(() => {
|
||||
init()
|
||||
}, [init])
|
||||
|
|
|
|||
|
|
@ -1,99 +0,0 @@
|
|||
'use client'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { jwtDecode } from 'jwt-decode'
|
||||
import dayjs from 'dayjs'
|
||||
import utc from 'dayjs/plugin/utc'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import type { CommonResponse } from '@/models/common'
|
||||
import { fetchNewToken } from '@/service/common'
|
||||
import { fetchWithRetry } from '@/utils'
|
||||
|
||||
dayjs.extend(utc)
|
||||
|
||||
const useRefreshToken = () => {
|
||||
const router = useRouter()
|
||||
const timer = useRef<NodeJS.Timeout>()
|
||||
const advanceTime = useRef<number>(5 * 60 * 1000)
|
||||
|
||||
const getExpireTime = useCallback((token: string) => {
|
||||
if (!token)
|
||||
return 0
|
||||
const decoded = jwtDecode(token)
|
||||
return (decoded.exp || 0) * 1000
|
||||
}, [])
|
||||
|
||||
const getCurrentTimeStamp = useCallback(() => {
|
||||
return dayjs.utc().valueOf()
|
||||
}, [])
|
||||
|
||||
const handleError = useCallback(() => {
|
||||
localStorage?.removeItem('is_refreshing')
|
||||
localStorage?.removeItem('console_token')
|
||||
localStorage?.removeItem('refresh_token')
|
||||
router.replace('/signin')
|
||||
}, [])
|
||||
|
||||
const getNewAccessToken = useCallback(async () => {
|
||||
const currentAccessToken = localStorage?.getItem('console_token')
|
||||
const currentRefreshToken = localStorage?.getItem('refresh_token')
|
||||
if (!currentAccessToken || !currentRefreshToken) {
|
||||
handleError()
|
||||
return new Error('No access token or refresh token found')
|
||||
}
|
||||
if (localStorage?.getItem('is_refreshing') === '1') {
|
||||
clearTimeout(timer.current)
|
||||
timer.current = setTimeout(() => {
|
||||
getNewAccessToken()
|
||||
}, 1000)
|
||||
return null
|
||||
}
|
||||
const currentTokenExpireTime = getExpireTime(currentAccessToken)
|
||||
if (getCurrentTimeStamp() + advanceTime.current > currentTokenExpireTime) {
|
||||
localStorage?.setItem('is_refreshing', '1')
|
||||
const [e, res] = await fetchWithRetry(fetchNewToken({
|
||||
body: { refresh_token: currentRefreshToken },
|
||||
}) as Promise<CommonResponse & { data: { access_token: string; refresh_token: string } }>)
|
||||
if (e) {
|
||||
handleError()
|
||||
return e
|
||||
}
|
||||
const { access_token, refresh_token } = res.data
|
||||
localStorage?.setItem('is_refreshing', '0')
|
||||
localStorage?.setItem('console_token', access_token)
|
||||
localStorage?.setItem('refresh_token', refresh_token)
|
||||
const newTokenExpireTime = getExpireTime(access_token)
|
||||
clearTimeout(timer.current)
|
||||
timer.current = setTimeout(() => {
|
||||
getNewAccessToken()
|
||||
}, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
|
||||
}
|
||||
else {
|
||||
const newTokenExpireTime = getExpireTime(currentAccessToken)
|
||||
clearTimeout(timer.current)
|
||||
timer.current = setTimeout(() => {
|
||||
getNewAccessToken()
|
||||
}, newTokenExpireTime - advanceTime.current - getCurrentTimeStamp())
|
||||
}
|
||||
return null
|
||||
}, [getExpireTime, getCurrentTimeStamp, handleError])
|
||||
|
||||
const handleVisibilityChange = useCallback(() => {
|
||||
if (document.visibilityState === 'visible')
|
||||
getNewAccessToken()
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener('visibilitychange', handleVisibilityChange)
|
||||
return () => {
|
||||
window.removeEventListener('visibilitychange', handleVisibilityChange)
|
||||
clearTimeout(timer.current)
|
||||
localStorage?.removeItem('is_refreshing')
|
||||
}
|
||||
}, [])
|
||||
|
||||
return {
|
||||
getNewAccessToken,
|
||||
}
|
||||
}
|
||||
|
||||
export default useRefreshToken
|
||||
|
|
@ -557,6 +557,23 @@ const translation = {
|
|||
iteration_one: '{{count}} Iteration',
|
||||
iteration_other: '{{count}} Iterationen',
|
||||
currentIteration: 'Aktuelle Iteration',
|
||||
ErrorMethod: {
|
||||
operationTerminated: 'beendet',
|
||||
removeAbnormalOutput: 'remove-abnormale_ausgabe',
|
||||
continueOnError: 'Fehler "Fortfahren bei"',
|
||||
},
|
||||
MaxParallelismTitle: 'Maximale Parallelität',
|
||||
parallelMode: 'Paralleler Modus',
|
||||
errorResponseMethod: 'Methode der Fehlerantwort',
|
||||
error_one: '{{Anzahl}} Fehler',
|
||||
error_other: '{{Anzahl}} Irrtümer',
|
||||
MaxParallelismDesc: 'Die maximale Parallelität wird verwendet, um die Anzahl der Aufgaben zu steuern, die gleichzeitig in einer einzigen Iteration ausgeführt werden.',
|
||||
parallelPanelDesc: 'Im parallelen Modus unterstützen Aufgaben in der Iteration die parallele Ausführung.',
|
||||
parallelModeEnableDesc: 'Im parallelen Modus unterstützen Aufgaben innerhalb von Iterationen die parallele Ausführung. Sie können dies im Eigenschaftenbereich auf der rechten Seite konfigurieren.',
|
||||
answerNodeWarningDesc: 'Warnung im parallelen Modus: Antwortknoten, Zuweisungen von Konversationsvariablen und persistente Lese-/Schreibvorgänge innerhalb von Iterationen können Ausnahmen verursachen.',
|
||||
parallelModeEnableTitle: 'Paralleler Modus aktiviert',
|
||||
parallelModeUpper: 'PARALLELER MODUS',
|
||||
comma: ',',
|
||||
},
|
||||
note: {
|
||||
editor: {
|
||||
|
|
|
|||
|
|
@ -556,6 +556,23 @@ const translation = {
|
|||
iteration_one: '{{count}} Iteration',
|
||||
iteration_other: '{{count}} Iterations',
|
||||
currentIteration: 'Current Iteration',
|
||||
comma: ', ',
|
||||
error_one: '{{count}} Error',
|
||||
error_other: '{{count}} Errors',
|
||||
parallelMode: 'Parallel Mode',
|
||||
parallelModeUpper: 'PARALLEL MODE',
|
||||
parallelModeEnableTitle: 'Parallel Mode Enabled',
|
||||
parallelModeEnableDesc: 'In parallel mode, tasks within iterations support parallel execution. You can configure this in the properties panel on the right.',
|
||||
parallelPanelDesc: 'In parallel mode, tasks in the iteration support parallel execution.',
|
||||
MaxParallelismTitle: 'Maximum parallelism',
|
||||
MaxParallelismDesc: 'The maximum parallelism is used to control the number of tasks executed simultaneously in a single iteration.',
|
||||
errorResponseMethod: 'Error response method',
|
||||
ErrorMethod: {
|
||||
operationTerminated: 'terminated',
|
||||
continueOnError: 'continue-on-error',
|
||||
removeAbnormalOutput: 'remove-abnormal-output',
|
||||
},
|
||||
answerNodeWarningDesc: 'Parallel mode warning: Answer nodes, conversation variable assignments, and persistent read/write operations within iterations may cause exceptions.',
|
||||
},
|
||||
note: {
|
||||
addNote: 'Add Note',
|
||||
|
|
|
|||
|
|
@ -557,6 +557,23 @@ const translation = {
|
|||
iteration_one: '{{count}} Iteración',
|
||||
iteration_other: '{{count}} Iteraciones',
|
||||
currentIteration: 'Iteración actual',
|
||||
ErrorMethod: {
|
||||
operationTerminated: 'Terminado',
|
||||
continueOnError: 'Continuar en el error',
|
||||
removeAbnormalOutput: 'eliminar-salida-anormal',
|
||||
},
|
||||
comma: ',',
|
||||
errorResponseMethod: 'Método de respuesta a errores',
|
||||
error_one: '{{conteo}} Error',
|
||||
parallelPanelDesc: 'En el modo paralelo, las tareas de la iteración admiten la ejecución en paralelo.',
|
||||
MaxParallelismTitle: 'Máximo paralelismo',
|
||||
error_other: '{{conteo}} Errores',
|
||||
parallelMode: 'Modo paralelo',
|
||||
parallelModeEnableDesc: 'En el modo paralelo, las tareas dentro de las iteraciones admiten la ejecución en paralelo. Puede configurar esto en el panel de propiedades a la derecha.',
|
||||
parallelModeUpper: 'MODO PARALELO',
|
||||
MaxParallelismDesc: 'El paralelismo máximo se utiliza para controlar el número de tareas ejecutadas simultáneamente en una sola iteración.',
|
||||
answerNodeWarningDesc: 'Advertencia de modo paralelo: Los nodos de respuesta, las asignaciones de variables de conversación y las operaciones de lectura/escritura persistentes dentro de las iteraciones pueden provocar excepciones.',
|
||||
parallelModeEnableTitle: 'Modo paralelo habilitado',
|
||||
},
|
||||
note: {
|
||||
addNote: 'Agregar nota',
|
||||
|
|
|
|||
|
|
@ -557,6 +557,23 @@ const translation = {
|
|||
iteration_one: '{{count}} تکرار',
|
||||
iteration_other: '{{count}} تکرارها',
|
||||
currentIteration: 'تکرار فعلی',
|
||||
ErrorMethod: {
|
||||
continueOnError: 'ادامه در خطا',
|
||||
operationTerminated: 'فسخ',
|
||||
removeAbnormalOutput: 'حذف خروجی غیرطبیعی',
|
||||
},
|
||||
error_one: '{{تعداد}} خطا',
|
||||
error_other: '{{تعداد}} خطاهای',
|
||||
parallelMode: 'حالت موازی',
|
||||
errorResponseMethod: 'روش پاسخ به خطا',
|
||||
parallelModeEnableTitle: 'حالت موازی فعال است',
|
||||
parallelModeUpper: 'حالت موازی',
|
||||
comma: ',',
|
||||
parallelModeEnableDesc: 'در حالت موازی، وظایف درون تکرارها از اجرای موازی پشتیبانی می کنند. می توانید این را در پانل ویژگی ها در سمت راست پیکربندی کنید.',
|
||||
MaxParallelismTitle: 'حداکثر موازی سازی',
|
||||
parallelPanelDesc: 'در حالت موازی، وظایف در تکرار از اجرای موازی پشتیبانی می کنند.',
|
||||
MaxParallelismDesc: 'حداکثر موازی سازی برای کنترل تعداد وظایف اجرا شده به طور همزمان در یک تکرار واحد استفاده می شود.',
|
||||
answerNodeWarningDesc: 'هشدار حالت موازی: گره های پاسخ، تکالیف متغیر مکالمه و عملیات خواندن/نوشتن مداوم در تکرارها ممکن است باعث استثنائات شود.',
|
||||
},
|
||||
note: {
|
||||
addNote: 'افزودن یادداشت',
|
||||
|
|
|
|||
|
|
@ -557,6 +557,23 @@ const translation = {
|
|||
iteration_one: '{{count}} Itération',
|
||||
iteration_other: '{{count}} Itérations',
|
||||
currentIteration: 'Itération actuelle',
|
||||
ErrorMethod: {
|
||||
operationTerminated: 'Terminé',
|
||||
removeAbnormalOutput: 'remove-abnormal-output',
|
||||
continueOnError: 'continuer sur l’erreur',
|
||||
},
|
||||
comma: ',',
|
||||
error_one: '{{compte}} Erreur',
|
||||
error_other: '{{compte}} Erreurs',
|
||||
parallelModeEnableDesc: 'En mode parallèle, les tâches au sein des itérations prennent en charge l’exécution parallèle. Vous pouvez le configurer dans le panneau des propriétés à droite.',
|
||||
parallelModeUpper: 'MODE PARALLÈLE',
|
||||
parallelPanelDesc: 'En mode parallèle, les tâches de l’itération prennent en charge l’exécution parallèle.',
|
||||
MaxParallelismDesc: 'Le parallélisme maximal est utilisé pour contrôler le nombre de tâches exécutées simultanément en une seule itération.',
|
||||
errorResponseMethod: 'Méthode de réponse aux erreurs',
|
||||
MaxParallelismTitle: 'Parallélisme maximal',
|
||||
answerNodeWarningDesc: 'Avertissement en mode parallèle : les nœuds de réponse, les affectations de variables de conversation et les opérations de lecture/écriture persistantes au sein des itérations peuvent provoquer des exceptions.',
|
||||
parallelModeEnableTitle: 'Mode parallèle activé',
|
||||
parallelMode: 'Mode parallèle',
|
||||
},
|
||||
note: {
|
||||
addNote: 'Ajouter note',
|
||||
|
|
|
|||
|
|
@ -577,6 +577,23 @@ const translation = {
|
|||
iteration_one: '{{count}} इटरेशन',
|
||||
iteration_other: '{{count}} इटरेशन्स',
|
||||
currentIteration: 'वर्तमान इटरेशन',
|
||||
ErrorMethod: {
|
||||
operationTerminated: 'समाप्त',
|
||||
continueOnError: 'जारी रखें-पर-त्रुटि',
|
||||
removeAbnormalOutput: 'निकालें-असामान्य-आउटपुट',
|
||||
},
|
||||
comma: ',',
|
||||
error_other: '{{गिनती}} त्रुटियों',
|
||||
error_one: '{{गिनती}} चूक',
|
||||
parallelMode: 'समानांतर मोड',
|
||||
parallelModeUpper: 'समानांतर मोड',
|
||||
errorResponseMethod: 'त्रुटि प्रतिक्रिया विधि',
|
||||
MaxParallelismTitle: 'अधिकतम समांतरता',
|
||||
parallelModeEnableTitle: 'समानांतर मोड सक्षम किया गया',
|
||||
parallelModeEnableDesc: 'समानांतर मोड में, पुनरावृत्तियों के भीतर कार्य समानांतर निष्पादन का समर्थन करते हैं। आप इसे दाईं ओर गुण पैनल में कॉन्फ़िगर कर सकते हैं।',
|
||||
parallelPanelDesc: 'समानांतर मोड में, पुनरावृत्ति में कार्य समानांतर निष्पादन का समर्थन करते हैं।',
|
||||
MaxParallelismDesc: 'अधिकतम समांतरता का उपयोग एकल पुनरावृत्ति में एक साथ निष्पादित कार्यों की संख्या को नियंत्रित करने के लिए किया जाता है।',
|
||||
answerNodeWarningDesc: 'समानांतर मोड चेतावनी: उत्तर नोड्स, वार्तालाप चर असाइनमेंट, और पुनरावृत्तियों के भीतर लगातार पढ़ने/लिखने की कार्रवाई अपवाद पैदा कर सकती है।',
|
||||
},
|
||||
note: {
|
||||
addNote: 'नोट जोड़ें',
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue