diff --git a/api/commands.py b/api/commands.py index 09548ac9f3..bf013cc77e 100644 --- a/api/commands.py +++ b/api/commands.py @@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str if language not in languages: language = "en-US" - name = name.strip() + # Validates name encoding for non-Latin characters. + name = name.strip().encode("utf-8").decode("utf-8") if name else None # generate random password new_password = secrets.token_urlsafe(16) diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py index c71f1ce5a3..9f762b3135 100644 --- a/api/controllers/common/errors.py +++ b/api/controllers/common/errors.py @@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException class FilenameNotExistsError(HTTPException): code = 400 description = "The specified filename does not exist." + + +class RemoteFileUploadError(HTTPException): + code = 400 + description = "Error uploading remote file." diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 6f9d7769b9..5e7a3da017 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,12 +1,14 @@ from flask_login import current_user from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=current_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=current_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.EXPLORE, + pinned=pinned, + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3d221ff30a..4e11d8005f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index fac1341b39..b8cf019e4f 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse import services from controllers.common import helpers +from 
controllers.common.errors import RemoteFileUploadError from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields @@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2cd6dcda3b..9e62a54699 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -3,12 +3,14 @@ import io from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required from services.tools.api_tools_manage_service import ApiToolManageService @@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource): args = parser.parse_args() - return BuiltinToolManageService.update_builtin_tool_provider( - user_id, - tenant_id, - provider, - args["credentials"], - ) + with Session(db.engine) as session: + result = BuiltinToolManageService.update_builtin_tool_provider( + session=session, + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + ) + session.commit() + return result class ToolBuiltinProviderGetCredentialsApi(Resource): @@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user_id = current_user.id tenant_id = current_user.current_tenant_id return BuiltinToolManageService.get_builtin_tool_provider_credentials( - user_id, - tenant_id, - provider, + tenant_id=tenant_id, + provider_name=provider, ) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c62fd77d36..32940cbc29 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,6 @@ from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound import services @@ -7,6 +8,7 @@ from controllers.service_api import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from 
fields.conversation_fields import ( conversation_delete_fields, conversation_infinite_scroll_pagination_fields, @@ -39,14 +41,16 @@ class ConversationApi(Resource): args = parser.parse_args() try: - return ConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return ConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], + ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index ada40ec9cb..599401bc6f 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index c3b0cd4f44..fe0d7c74f3 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,11 +1,13 @@ from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.WEB_APP, + pinned=pinned, + sort_by=args["sort_by"], + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 98891f5d00..febaab5328 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], 
args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index d6b8eb2855..ae68df6bdc 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy @@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 18b115dfe4..29709914b7 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -4,14 +4,17 @@ import logging import queue import re import threading +from collections.abc import Iterable from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentMessageEvent, QueueLLMChunkEvent, QueueNodeSucceededEvent, QueueTextChunkEvent, + WorkflowQueueMessage, ) -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -21,7 +24,7 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): ) -def _process_future(future_queue, audio_queue): +def _process_future( + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None], + audio_queue: queue.Queue[AudioTrunk], +): while True: try: future = future_queue.get() if future is None: break - for audio in future.result(): + invoke_result = future.result() + if not invoke_result: + continue + for audio in invoke_result: audio_base64 = base64.b64encode(bytes(audio)) audio_queue.put(AudioTrunk("responding", audio=audio_base64)) except Exception as e: @@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher: self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" - self._audio_queue = queue.Queue() - self._msg_queue = queue.Queue() + self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() + self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( @@ -66,14 +75,11 
@@ class AppGeneratorTTSPublisher: self._runtime_thread = threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) - def publish(self, message): - try: - self._msg_queue.put(message) - except Exception as e: - self.logger.warning(e) + def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): + self._msg_queue.put(message) def _runtime(self): - future_queue = queue.Queue() + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue() threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() while True: try: @@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher: break future_queue.put(None) - def check_and_get_audio(self) -> AudioTrunk | None: + def check_and_get_audio(self): try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8e1731b314..c7bf37dd08 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -512,7 +512,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._message_to_stream_response( diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 4c4d282e99..3725c6e6dd 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,7 +1,6 @@ import queue import time from abc import abstractmethod -from collections.abc import Generator from enum import Enum from typing import Any @@ -11,9 +10,11 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueErrorEvent, QueuePingEvent, QueueStopEvent, + WorkflowQueueMessage, ) from extensions.ext_redis import redis_client @@ -37,11 +38,11 @@ class AppQueueManager: AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) - q = queue.Queue() + q: queue.Queue[WorkflowQueueMessage | 
MessageQueueMessage | None] = queue.Queue() self._q = q - def listen(self) -> Generator: + def listen(self): """ Listen to queue :return: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index b129904efb..fd84908975 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 917649f34e..4216cd46cf 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 951fef1fa1..5061804310 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -508,6 +508,12 @@ class WorkflowCycleManage: :param workflow_run: workflow run :return: """ + # Attach WorkflowRun to an active session so "created_by_role" can be accessed. 
+ workflow_run = db.session.merge(workflow_run) + + # Refresh to ensure any expired attributes are fully loaded + db.session.refresh(workflow_run) + created_by = None if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: created_by_account = workflow_run.created_by_account diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 3b186476eb..ad921bc255 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,7 +1,7 @@ from typing import Optional -class LLMError(Exception): +class LLMError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None @@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError): description = "Bad Request" -class ProviderTokenNotInitError(Exception): +class ProviderTokenNotInitError(ValueError): """ Custom exception raised when the provider token is not initialized. """ @@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception): self.description = args[0] if args else self.description -class QuotaExceededError(Exception): +class QuotaExceededError(ValueError): """ Custom exception raised when the quota for a provider has been exceeded. """ @@ -35,7 +35,7 @@ class QuotaExceededError(Exception): description = "Quota Exceeded" -class AppInvokeQuotaExceededError(Exception): +class AppInvokeQuotaExceededError(ValueError): """ Custom exception raised when the quota for an app has been exceeded. """ @@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception): description = "App Invoke Quota Exceeded" -class ModelCurrentlyNotSupportError(Exception): +class ModelCurrentlyNotSupportError(ValueError): """ Custom exception raised when the model not support """ @@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception): description = "Model Currently Not Support" -class InvokeRateLimitError(Exception): +class InvokeRateLimitError(ValueError): """Raised when the Invoke returns rate limit error.""" description = "Rate Limit Error" diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 011ff382ea..584e3e9698 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -118,7 +118,7 @@ class CodeExecutor: return response.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): """ Execute code :param language: code language diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b7a07b21e1..605719747a 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -25,7 +25,7 @@ class TemplateTransformer(ABC): return runner_script, preload_script @classmethod - def extract_result_str_from_response(cls, response: str) -> str: + def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") @@ -33,13 +33,21 @@ class TemplateTransformer(ABC): return result @classmethod - def transform_response(cls, response: str) -> dict: + def transform_response(cls, response: str) -> Mapping[str, Any]: """ Transform response to dict :param response: response :return: """ - return json.loads(cls.extract_result_str_from_response(response)) + try: + 
result = json.loads(cls.extract_result_str_from_response(response)) + except json.JSONDecodeError: + raise ValueError("failed to parse response") + if not isinstance(result, dict): + raise ValueError("result must be a dict") + if not all(isinstance(k, str) for k in result): + raise ValueError("result keys must be strings") + return result @classmethod @abstractmethod diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 425b3535c4..424983a819 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] -class MaxRetriesExceededError(Exception): +class MaxRetriesExceededError(ValueError): """Raised when the maximum number of retries is exceeded.""" pass diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 1e743f1757..0922806ca8 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserError(Exception): +class OutputParserError(ValueError): pass diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index edfb19c7d0..7675425361 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,7 +1,7 @@ from typing import Optional -class InvokeError(Exception): +class InvokeError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 7fcd2133f9..16bebcc67d 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -1,4 +1,4 @@ -class CredentialsValidateFailedError(Exception): +class CredentialsValidateFailedError(ValueError): """ Credentials validate failed error """ diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index e1d35ff872..c0ea8c6325 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "source": { "type": "base64", "media_type": message_content.mime_type, - "data": message_content.data, + "data": message_content.base64_data, }, } sub_messages.append(sub_message_dict) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b54668a12d..7d19ccbb74 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import ( PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ try: - ping_message = SystemPromptMessage(content="ping") + ping_message = UserPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_output_tokens": 5}) except Exception as ex: @@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel): config_kwargs["stop_sequences"] = stop genai.configure(api_key=credentials["google_api_key"]) - google_model = genai.GenerativeModel(model_name=model) history = [] + system_instruction = None for 
msg in prompt_messages:  # makes message roles strictly alternating
             content = self._format_message_to_glm_content(msg)
             if history and history[-1]["role"] == content["role"]:
                 history[-1]["parts"].extend(content["parts"])
+            elif content["role"] == "system":
+                system_instruction = content["parts"][0]
             else:
                 history.append(content)

+        if not history:
+            raise InvokeError("The user prompt message is required. You only added a system prompt message.")
+
+        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
         response = google_model.generate_content(
             contents=history,
             generation_config=genai.types.GenerationConfig(**config_kwargs),
@@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             )
             return glm_content
         elif isinstance(message, SystemPromptMessage):
-            return {"role": "user", "parts": [to_part(message.content)]}
+            if isinstance(message.content, list):
+                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
+                message.content = "".join(c.data for c in text_contents)
+            return {"role": "system", "parts": [to_part(message.content)]}
         elif isinstance(message, ToolPromptMessage):
             return {
                 "role": "function",
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index b7799ce1fb..a04fc6ee78 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -355,7 +355,13 @@ class TraceTask:
     def conversation_trace(self, **kwargs):
         return kwargs

-    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
+    def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
+        if not workflow_run:
+            raise ValueError("Workflow run not found")
+
+        db.session.merge(workflow_run)
+        db.session.refresh(workflow_run)
+
         workflow_id = workflow_run.workflow_id
         tenant_id = workflow_run.tenant_id
         workflow_run_id = workflow_run.id
diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py
index 992415657e..d17d76333e 100644
--- a/api/core/rag/data_post_processor/data_post_processor.py
+++ b/api/core/rag/data_post_processor/data_post_processor.py
@@ -83,11 +83,15 @@ class DataPostProcessor:
         if reranking_model:
             try:
                 model_manager = ModelManager()
+                reranking_provider_name = reranking_model.get("reranking_provider_name")
+                reranking_model_name = reranking_model.get("reranking_model_name")
+                if not reranking_provider_name or not reranking_model_name:
+                    return None
                 rerank_model_instance = model_manager.get_model_instance(
                     tenant_id=tenant_id,
-                    provider=reranking_model["reranking_provider_name"],
+                    provider=reranking_provider_name,
                     model_type=ModelType.RERANK,
-                    model=reranking_model["reranking_model_name"],
+                    model=reranking_model_name,
                 )
                 return rerank_model_instance
             except InvokeAuthorizationError:
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index b2141396d6..18f8d4e839 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -103,7 +103,7 @@ class RetrievalService:

         if exceptions:
             exception_message = ";\n".join(exceptions)
-            raise Exception(exception_message)
+            raise ValueError(exception_message)

         if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
             data_post_processor = DataPostProcessor(
diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py
index 09f328cd1f..582ad636b1 100644
---
a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController): user_input_form_list = app_model_config.user_input_form_list for input_form in user_input_form_list: # get type - form_type = input_form.keys()[0] + form_type = list(input_form.keys())[0] default = input_form[form_type]["default"] required = input_form[form_type]["required"] label = input_form[form_type]["label"] diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py new file mode 100644 index 0000000000..050b468b74 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py @@ -0,0 +1,115 @@ +import json +import operator +from typing import Any, Optional, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class BedrockRetrieveTool(BuiltinTool): + bedrock_client: Any = None + knowledge_base_id: str = None + topk: int = None + + def _bedrock_retrieve( + self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None + ): + try: + retrieval_query = {"text": query_input} + + retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} + + # 如果有元数据过滤条件,则添加到检索配置中 + if metadata_filter: + retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter + + response = self.bedrock_client.retrieve( + knowledgeBaseId=knowledge_base_id, + retrievalQuery=retrieval_query, + retrievalConfiguration=retrieval_configuration, + ) + + results = [] + for result in response.get("retrievalResults", []): + results.append( + { + "content": result.get("content", {}).get("text", ""), + "score": result.get("score", 0.0), + "metadata": result.get("metadata", {}), + } + ) + + return results + except Exception as e: + raise Exception(f"Error retrieving from knowledge base: {str(e)}") + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + line = 0 + try: + if not self.bedrock_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region) + else: + self.bedrock_client = boto3.client("bedrock-agent-runtime") + + line = 1 + if not self.knowledge_base_id: + self.knowledge_base_id = tool_parameters.get("knowledge_base_id") + if not self.knowledge_base_id: + return self.create_text_message("Please provide knowledge_base_id") + + line = 2 + if not self.topk: + self.topk = tool_parameters.get("topk", 5) + + line = 3 + query = tool_parameters.get("query", "") + if not query: + return self.create_text_message("Please input query") + + # 获取元数据过滤条件(如果存在) + metadata_filter_str = tool_parameters.get("metadata_filter") + metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None + + line = 4 + retrieved_docs = self._bedrock_retrieve( + query_input=query, + knowledge_base_id=self.knowledge_base_id, + num_results=self.topk, + metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法 + ) + + line = 5 + # Sort results by score in descending order + sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True) + + line = 6 + return [self.create_json_message(res) for res in sorted_docs] + + except Exception as e: + return 
self.create_text_message(f"Exception {str(e)}, line : {line}") + + def validate_parameters(self, parameters: dict[str, Any]) -> None: + """ + Validate the parameters + """ + if not parameters.get("knowledge_base_id"): + raise ValueError("knowledge_base_id is required") + + if not parameters.get("query"): + raise ValueError("query is required") + + # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供) + metadata_filter_str = parameters.get("metadata_filter") + if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict): + raise ValueError("metadata_filter must be a valid JSON object") diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml new file mode 100644 index 0000000000..9e51d52def --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml @@ -0,0 +1,87 @@ +identity: + name: bedrock_retrieve + author: AWS + label: + en_US: Bedrock Retrieve + zh_Hans: Bedrock检索 + pt_BR: Bedrock Retrieve + icon: icon.svg + +description: + human: + en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明 + pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. + llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool + +parameters: + - name: knowledge_base_id + type: string + required: true + label: + en_US: Bedrock Knowledge Base ID + zh_Hans: Bedrock知识库ID + pt_BR: Bedrock Knowledge Base ID + human_description: + en_US: ID of the Bedrock Knowledge Base to retrieve from + zh_Hans: 用于检索的Bedrock知识库ID + pt_BR: ID of the Bedrock Knowledge Base to retrieve from + llm_description: ID of the Bedrock Knowledge Base to retrieve from + form: form + + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + pt_BR: Query string + human_description: + en_US: The search query to retrieve relevant information + zh_Hans: 用于检索相关信息的查询语句 + pt_BR: The search query to retrieve relevant information + llm_description: The search query to retrieve relevant information + form: llm + + - name: topk + type: number + required: false + form: form + label: + en_US: Limit for results count + zh_Hans: 返回结果数量限制 + pt_BR: Limit for results count + human_description: + en_US: Maximum number of results to return + zh_Hans: 最大返回结果数量 + pt_BR: Maximum number of results to return + min: 1 + max: 10 + default: 5 + + - name: aws_region + type: string + required: false + label: + en_US: AWS Region + zh_Hans: AWS 区域 + pt_BR: AWS Region + human_description: + en_US: AWS region where the Bedrock Knowledge Base is located + zh_Hans: Bedrock知识库所在的AWS区域 + pt_BR: AWS region where the Bedrock Knowledge Base is located + llm_description: AWS region where the Bedrock Knowledge Base is located + form: form + + - name: metadata_filter + type: string + required: false + label: + en_US: Metadata Filter + zh_Hans: 元数据过滤器 + pt_BR: Metadata Filter + human_description: + en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' + zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})' + pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: 
"aaa", "value": 10}})' + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.py b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py new file mode 100644 index 0000000000..954dbe35a4 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py @@ -0,0 +1,357 @@ +import base64 +import json +import logging +import re +from datetime import datetime +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class NovaCanvasTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Canvas model for image generation + """ + # Get common parameters + prompt = tool_parameters.get("prompt", "") + image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip() + if not prompt: + return self.create_text_message("Please provide a text prompt for image generation.") + if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3": + return self.create_text_message("Please provide an valid S3 URI for image output.") + + task_type = tool_parameters.get("task_type", "TEXT_IMAGE") + aws_region = tool_parameters.get("aws_region", "us-east-1") + + # Get common image generation config parameters + width = tool_parameters.get("width", 1024) + height = tool_parameters.get("height", 1024) + cfg_scale = tool_parameters.get("cfg_scale", 8.0) + negative_prompt = tool_parameters.get("negative_prompt", "") + seed = tool_parameters.get("seed", 0) + quality = tool_parameters.get("quality", "standard") + + # Handle S3 image if provided + image_input_s3uri = tool_parameters.get("image_input_s3uri", "") + if task_type != "TEXT_IMAGE": + if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3": + return self.create_text_message("Please provide a valid S3 URI for image to image generation.") + + # Parse S3 URI + parsed_uri = urlparse(image_input_s3uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + # Initialize S3 client and download image + s3_client = boto3.client("s3") + response = s3_client.get_object(Bucket=bucket, Key=key) + image_data = response["Body"].read() + + # Base64 encode the image + input_image = base64.b64encode(image_data).decode("utf-8") + + try: + # Initialize Bedrock client + bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region) + + # Base image generation config + image_generation_config = { + "width": width, + "height": height, + "cfgScale": cfg_scale, + "seed": seed, + "numberOfImages": 1, + "quality": quality, + } + + # Prepare request body based on task type + body = {"imageGenerationConfig": image_generation_config} + + if task_type == "TEXT_IMAGE": + body["taskType"] = "TEXT_IMAGE" + body["textToImageParams"] = {"text": prompt} + if negative_prompt: + body["textToImageParams"]["negativeText"] = negative_prompt + + elif task_type == "COLOR_GUIDED_GENERATION": + colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680") + if not self._validate_color_string(colors): + return self.create_text_message("Please provide valid colors in hexadecimal format.") + + body["taskType"] = "COLOR_GUIDED_GENERATION" + body["colorGuidedGenerationParams"] = { 
+ "colors": colors.split("-"), + "referenceImage": input_image, + "text": prompt, + } + if negative_prompt: + body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt + + elif task_type == "IMAGE_VARIATION": + similarity_strength = tool_parameters.get("similarity_strength", 0.5) + + body["taskType"] = "IMAGE_VARIATION" + body["imageVariationParams"] = { + "images": [input_image], + "similarityStrength": similarity_strength, + "text": prompt, + } + if negative_prompt: + body["imageVariationParams"]["negativeText"] = negative_prompt + + elif task_type == "INPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image inpainting.") + + body["taskType"] = "INPAINTING" + body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt} + if negative_prompt: + body["inPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "OUTPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image outpainting.") + outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT") + + body["taskType"] = "OUTPAINTING" + body["outPaintingParams"] = { + "image": input_image, + "maskPrompt": mask_prompt, + "outPaintingMode": outpainting_mode, + "text": prompt, + } + if negative_prompt: + body["outPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "BACKGROUND_REMOVAL": + body["taskType"] = "BACKGROUND_REMOVAL" + body["backgroundRemovalParams"] = {"image": input_image} + + else: + return self.create_text_message(f"Unsupported task type: {task_type}") + + # Call Nova Canvas model + response = bedrock.invoke_model( + body=json.dumps(body), + modelId="amazon.nova-canvas-v1:0", + accept="application/json", + contentType="application/json", + ) + + # Process response + response_body = json.loads(response.get("body").read()) + if response_body.get("error"): + raise Exception(f"Error in model response: {response_body.get('error')}") + base64_image = response_body.get("images")[0] + + # Upload to S3 if image_output_s3uri is provided + try: + # Parse S3 URI for output + parsed_uri = urlparse(image_output_s3uri) + output_bucket = parsed_uri.netloc + output_base_path = parsed_uri.path.lstrip("/") + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_key = f"{output_base_path}/canvas-output-{timestamp}.png" + + # Initialize S3 client if not already done + s3_client = boto3.client("s3", region_name=aws_region) + + # Decode base64 image and upload to S3 + image_data = base64.b64decode(base64_image) + s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png") + logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}") + except Exception as e: + logger.exception("Failed to upload image to S3") + # Return image + return [ + self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"), + self.create_blob_message( + blob=base64.b64decode(base64_image), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ), + ] + + except Exception as e: + return self.create_text_message(f"Failed to generate image: {str(e)}") + + def _validate_color_string(self, color_string) -> bool: + color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$" + + if re.match(color_pattern, color_string): + return True + return False + + def 
get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the image you want to generate or modify", + zh_Hans="您想要生成或修改的图像的文本描述", + ), + llm_description="Describe the image you want to generate or how you want to modify the input image", + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"), + ), + ToolParameter( + name="image_output_s3uri", + label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI" + ), + ), + ToolParameter( + name="width", + label=I18nObject(en_US="Width", zh_Hans="宽度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"), + ), + ToolParameter( + name="height", + label=I18nObject(en_US="Height", zh_Hans="高度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"), + ), + ToolParameter( + name="cfg_scale", + label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=8.0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词" + ), + ), + ToolParameter( + name="negative_prompt", + label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容" + ), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="us-east-1", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="task_type", + label=I18nObject(en_US="Task Type", zh_Hans="任务类型"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="TEXT_IMAGE", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"), + ), + ToolParameter( + name="quality", + label=I18nObject(en_US="Quality", zh_Hans="质量"), + 
type=ToolParameter.ToolParameterType.STRING, + required=False, + default="standard", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)" + ), + ), + ToolParameter( + name="colors", + label=I18nObject(en_US="Colors", zh_Hans="颜色"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680", + zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680", + ), + ), + ToolParameter( + name="similarity_strength", + label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0.5, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How similar the generated image should be to the input image (0.0 to 1.0)", + zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)", + ), + ), + ToolParameter( + name="mask_prompt", + label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description to generate mask for inpainting/outpainting", + zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述", + ), + ), + ToolParameter( + name="outpainting_mode", + label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="DEFAULT", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Mode for outpainting (DEFAULT or other supported modes)", + zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml new file mode 100644 index 0000000000..a72fd9c8ef --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml @@ -0,0 +1,175 @@ +identity: + name: nova_canvas + author: AWS + label: + en_US: AWS Bedrock Nova Canvas + zh_Hans: AWS Bedrock Nova Canvas + icon: icon.svg +description: + human: + en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。 + llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. 
+parameters: + - name: task_type + type: string + required: false + default: TEXT_IMAGE + label: + en_US: Task Type + zh_Hans: 任务类型 + human_description: + en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL) + zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除) + form: llm + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the image you want to generate or modify + zh_Hans: 您想要生成或修改的图像的文本描述 + llm_description: Describe the image you want to generate or how you want to modify the input image + form: llm + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input image s3 uri + zh_Hans: 输入图片的s3 uri + human_description: + en_US: The input image to modify (required for all modes except TEXT_IMAGE) + zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要) + llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE. + form: llm + - name: image_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png + zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传 + llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename. + form: form + - name: negative_prompt + type: string + required: false + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + human_description: + en_US: Things you don't want in the generated image + zh_Hans: 您不想在生成的图像中出现的内容 + form: llm + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: 宽度 + human_description: + en_US: Width of the generated image + zh_Hans: 生成图像的宽度 + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: 高度 + human_description: + en_US: Height of the generated image + zh_Hans: 生成图像的高度 + form: form + default: 1024 + - name: cfg_scale + type: number + required: false + label: + en_US: CFG Scale + zh_Hans: CFG比例 + human_description: + en_US: How strongly the image should conform to the prompt + zh_Hans: 图像应该多大程度上符合提示词 + form: form + default: 8.0 + - name: seed + type: number + required: false + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for image generation + zh_Hans: 图像生成的随机种子 + form: form + default: 0 + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + - name: quality + type: string + required: false + default: standard + label: + en_US: Quality + zh_Hans: 质量 + human_description: + en_US: Quality of the generated image (standard or premium) + zh_Hans: 生成图像的质量(标准或高级) + form: form + - name: colors + type: string + required: false + label: + en_US: Colors + zh_Hans: 颜色 + human_description: + en_US: List of colors for color-guided generation + zh_Hans: 颜色引导生成的颜色列表 + form: form + - name: similarity_strength + type: number + required: false + default: 0.5 + label: + en_US: Similarity Strength + zh_Hans: 相似度强度 + human_description: + en_US: How similar the generated image should be to the input image (0.0 to 1.0) + zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0) + form: form + - name: 
mask_prompt + type: string + required: false + label: + en_US: Mask Prompt + zh_Hans: 蒙版提示词 + human_description: + en_US: Text description to generate mask for inpainting/outpainting + zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述 + form: llm + - name: outpainting_mode + type: string + required: false + default: DEFAULT + label: + en_US: Outpainting Mode + zh_Hans: 外补绘制模式 + human_description: + en_US: Mode for outpainting (DEFAULT or other supported modes) + zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式) + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.py b/api/core/tools/provider/builtin/aws/tools/nova_reel.py new file mode 100644 index 0000000000..bfd3d302b2 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.py @@ -0,0 +1,371 @@ +import base64 +import logging +import time +from io import BytesIO +from typing import Any, Optional, Union +from urllib.parse import urlparse + +import boto3 +from botocore.exceptions import ClientError +from PIL import Image + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +NOVA_REEL_DEFAULT_REGION = "us-east-1" +NOVA_REEL_DEFAULT_DIMENSION = "1280x720" +NOVA_REEL_DEFAULT_FPS = 24 +NOVA_REEL_DEFAULT_DURATION = 6 +NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0" +NOVA_REEL_STATUS_CHECK_INTERVAL = 5 + +# Image requirements +NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280 +NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720 +NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB" + + +class NovaReelTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Reel model for video generation. 
+ + Args: + user_id: The ID of the user making the request + tool_parameters: Dictionary containing the tool parameters + + Returns: + ToolInvokeMessage containing either the video content or status information + """ + try: + # Validate and extract parameters + params = self._validate_and_extract_parameters(tool_parameters) + if isinstance(params, ToolInvokeMessage): + return params + + # Initialize AWS clients + bedrock, s3_client = self._initialize_aws_clients(params["aws_region"]) + + # Prepare model input + model_input = self._prepare_model_input(params, s3_client) + if isinstance(model_input, ToolInvokeMessage): + return model_input + + # Start video generation + invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"]) + invocation_arn = invocation["invocationArn"] + + # Handle async/sync mode + return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"]) + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + error_message = e.response.get("Error", {}).get("Message", str(e)) + logger.exception(f"AWS API error: {error_code} - {error_message}") + return self.create_text_message(f"AWS service error: {error_code} - {error_message}") + except Exception as e: + logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to generate video: {str(e)}") + + def _validate_and_extract_parameters( + self, tool_parameters: dict[str, Any] + ) -> Union[dict[str, Any], ToolInvokeMessage]: + """Validate and extract parameters from the input dictionary.""" + prompt = tool_parameters.get("prompt", "") + video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip() + + # Validate required parameters + if not prompt: + return self.create_text_message("Please provide a text prompt for video generation.") + if not video_output_s3uri: + return self.create_text_message("Please provide an S3 URI for video output.") + + # Validate S3 URI format + if not video_output_s3uri.startswith("s3://"): + return self.create_text_message("Invalid S3 URI format. 
Must start with 's3://'") + + # Ensure S3 URI ends with '/' + video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/" + + return { + "prompt": prompt, + "video_output_s3uri": video_output_s3uri, + "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(), + "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION), + "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION), + "seed": int(tool_parameters.get("seed", 0)), + "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)), + "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)), + "async_mode": bool(tool_parameters.get("async", True)), + } + + def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]: + """Initialize AWS Bedrock and S3 clients.""" + bedrock = boto3.client(service_name="bedrock-runtime", region_name=region) + s3_client = boto3.client("s3", region_name=region) + return bedrock, s3_client + + def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]: + """Prepare the input for the Nova Reel model.""" + model_input = { + "taskType": "TEXT_VIDEO", + "textToVideoParams": {"text": params["prompt"]}, + "videoGenerationConfig": { + "durationSeconds": params["duration"], + "fps": params["fps"], + "dimension": params["dimension"], + "seed": params["seed"], + }, + } + + # Add image if provided + if params["image_input_s3uri"]: + try: + image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"]) + if not image_data: + return self.create_text_message("Failed to retrieve image from S3") + + # Process and validate image + processed_image = self._process_and_validate_image(image_data) + if isinstance(processed_image, ToolInvokeMessage): + return processed_image + + # Convert processed image to base64 + img_buffer = BytesIO() + processed_image.save(img_buffer, format="PNG") + img_buffer.seek(0) + input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") + + model_input["textToVideoParams"]["images"] = [ + {"format": "png", "source": {"bytes": input_image_base64}} + ] + except Exception as e: + logger.error(f"Error processing input image: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to process input image: {str(e)}") + + return model_input + + def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]: + """ + Process and validate the input image according to Nova Reel requirements. + + Requirements: + - Must be 1280x720 pixels + - Must be RGB format (8 bits per channel) + - If PNG, alpha channel must not have transparent/translucent pixels + """ + try: + # Open image + img = Image.open(BytesIO(image_data)) + + # Convert RGBA to RGB if needed, ensuring no transparency + if img.mode == "RGBA": + # Check for transparency + if img.getchannel("A").getextrema()[0] < 255: + return self.create_text_message( + "PNG image contains transparent or translucent pixels, which is not supported. " + "Please provide an image without transparency." + ) + # Convert to RGB + img = img.convert("RGB") + elif img.mode != "RGB": + # Convert any other mode to RGB + img = img.convert("RGB") + + # Validate/adjust dimensions + if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT): + logger.warning( + f"Image dimensions {img.size} do not match required dimensions " + f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..." 
+ ) + img = img.resize( + (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS + ) + + # Validate bit depth + if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE: + return self.create_text_message( + f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel" + ) + + return img + + except Exception as e: + logger.error(f"Error processing image: {str(e)}", exc_info=True) + return self.create_text_message( + "Failed to process image. Please ensure the image is a valid JPEG or PNG file." + ) + + def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]: + """Download and return image data from S3.""" + parsed_uri = urlparse(s3_uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + response = s3_client.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + + def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]: + """Start the async video generation process.""" + return bedrock.start_async_invoke( + modelId=NOVA_REEL_MODEL_ID, + modelInput=model_input, + outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}}, + ) + + def _handle_generation_mode( + self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool + ) -> ToolInvokeMessage: + """Handle async or sync video generation mode.""" + invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + video_uri = f"{video_path}/output.mp4" + + if async_mode: + return self.create_text_message( + f"Video generation started.\nInvocation ARN: {invocation_arn}\n" + f"Video will be available at: {video_uri}" + ) + + return self._wait_for_completion(bedrock, s3_client, invocation_arn) + + def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage: + """Wait for video generation completion and handle the result.""" + while True: + status_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + status = status_response["status"] + video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + + if status == "Completed": + return self._handle_completed_video(s3_client, video_path) + elif status == "Failed": + failure_message = status_response.get("failureMessage", "Unknown error") + return self.create_text_message(f"Video generation failed.\nError: {failure_message}") + elif status == "InProgress": + time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL) + else: + return self.create_text_message(f"Unexpected status: {status}") + + def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage: + """Handle completed video generation and return the result.""" + parsed_uri = urlparse(video_path) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + "/output.mp4" + + try: + response = s3_client.get_object(Bucket=bucket, Key=key) + video_content = response["Body"].read() + return [ + self.create_text_message(f"Video is available at: {video_path}/output.mp4"), + self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"), + ] + except Exception as e: + logger.error(f"Error downloading video: {str(e)}", exc_info=True) + return self.create_text_message( + f"Video generation completed but failed to download video: {str(e)}\n" + f"Video is available at: s3://{bucket}/{key}" + ) + + def get_runtime_parameters(self) -> list[ToolParameter]: + 
"""Define the tool's runtime parameters.""" + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述" + ), + llm_description="Describe the video you want to generate", + ), + ToolParameter( + name="video_output_s3uri", + label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI" + ), + ), + ToolParameter( + name="dimension", + label=I18nObject(en_US="Dimension", zh_Hans="尺寸"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_DIMENSION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"), + ), + ToolParameter( + name="duration", + label=I18nObject(en_US="Duration", zh_Hans="时长"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_DURATION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"), + ), + ToolParameter( + name="fps", + label=I18nObject(en_US="FPS", zh_Hans="帧率"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_FPS, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数" + ), + ), + ToolParameter( + name="async", + label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"), + type=ToolParameter.ToolParameterType.BOOLEAN, + required=False, + default=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)", + zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)", + ), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_REGION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame", + zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml new file mode 100644 index 0000000000..16df5ba5c9 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml @@ -0,0 +1,124 @@ +identity: + name: 
nova_reel + author: AWS + label: + en_US: AWS Bedrock Nova Reel + zh_Hans: AWS Bedrock Nova Reel + icon: icon.svg +description: + human: + en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution. + +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the video you want to generate + zh_Hans: 您想要生成的视频的文本描述 + llm_description: Describe the video you want to generate + form: llm + + - name: video_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: S3 URI where the generated video will be stored + zh_Hans: 生成的视频将存储的S3 URI + form: form + + - name: dimension + type: string + required: false + default: 1280x720 + label: + en_US: Dimension + zh_Hans: 尺寸 + human_description: + en_US: Video dimensions (width x height) + zh_Hans: 视频尺寸(宽 x 高) + form: form + + - name: duration + type: number + required: false + default: 6 + label: + en_US: Duration + zh_Hans: 时长 + human_description: + en_US: Video duration in seconds + zh_Hans: 视频时长(秒) + form: form + + - name: seed + type: number + required: false + default: 0 + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for video generation + zh_Hans: 视频生成的随机种子 + form: form + + - name: fps + type: number + required: false + default: 24 + label: + en_US: FPS + zh_Hans: 帧率 + human_description: + en_US: Frames per second for the generated video + zh_Hans: 生成视频的每秒帧数 + form: form + + - name: async + type: boolean + required: false + default: true + label: + en_US: Async Mode + zh_Hans: 异步模式 + human_description: + en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion) + zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成) + form: llm + + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input Image S3 URI + zh_Hans: 输入图像S3 URI + human_description: + en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame + zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI + form: llm + +development: + dependencies: + - boto3 + - pillow diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.py b/api/core/tools/provider/builtin/aws/tools/s3_operator.py new file mode 100644 index 0000000000..e4026b07a8 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.py @@ -0,0 +1,80 @@ +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class S3Operator(BuiltinTool): + s3_client: Any = None + + def _invoke( + self, + 
+        user_id: str,
+        tool_parameters: dict[str, Any],
+    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+        """
+        invoke tools
+        """
+        try:
+            # Initialize S3 client if not already done
+            if not self.s3_client:
+                aws_region = tool_parameters.get("aws_region")
+                if aws_region:
+                    self.s3_client = boto3.client("s3", region_name=aws_region)
+                else:
+                    self.s3_client = boto3.client("s3")
+
+            # Parse S3 URI
+            s3_uri = tool_parameters.get("s3_uri")
+            if not s3_uri:
+                return self.create_text_message("s3_uri parameter is required")
+
+            parsed_uri = urlparse(s3_uri)
+            if parsed_uri.scheme != "s3":
+                return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
+
+            bucket = parsed_uri.netloc
+            # Remove leading slash from key
+            key = parsed_uri.path.lstrip("/")
+
+            operation_type = tool_parameters.get("operation_type", "read")
+            generate_presign_url = tool_parameters.get("generate_presign_url", False)
+            presign_expiry = int(tool_parameters.get("presign_expiry", 3600))  # default 1 hour
+
+            if operation_type == "write":
+                text_content = tool_parameters.get("text_content")
+                if not text_content:
+                    return self.create_text_message("text_content parameter is required for write operation")
+
+                # Write content to S3
+                self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
+                result = f"s3://{bucket}/{key}"
+
+                # Generate presigned URL for the written object if requested
+                if generate_presign_url:
+                    result = self.s3_client.generate_presigned_url(
+                        "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
+                    )
+
+            else:  # read operation
+                # Get object from S3
+                response = self.s3_client.get_object(Bucket=bucket, Key=key)
+                result = response["Body"].read().decode("utf-8")
+
+                # Generate presigned URL if requested
+                if generate_presign_url:
+                    result = self.s3_client.generate_presigned_url(
+                        "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
+                    )
+
+            return self.create_text_message(text=result)
+
+        except self.s3_client.exceptions.NoSuchBucket:
+            return self.create_text_message(f"Bucket '{bucket}' does not exist")
+        except self.s3_client.exceptions.NoSuchKey:
+            return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
+        except Exception as e:
+            return self.create_text_message(f"Exception: {str(e)}")
diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml
new file mode 100644
index 0000000000..642fc2966e
--- /dev/null
+++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml
@@ -0,0 +1,98 @@
+identity:
+  name: s3_operator
+  author: AWS
+  label:
+    en_US: AWS S3 Operator
+    zh_Hans: AWS S3 读写器
+    pt_BR: AWS S3 Operator
+  icon: icon.svg
+description:
+  human:
+    en_US: AWS S3 Writer and Reader
+    zh_Hans: 读写S3 bucket中的文件
+    pt_BR: AWS S3 Writer and Reader
+  llm: AWS S3 Writer and Reader
+parameters:
+  - name: text_content
+    type: string
+    required: false
+    label:
+      en_US: The text to write
+      zh_Hans: 待写入的文本
+      pt_BR: The text to write
+    human_description:
+      en_US: The text to write
+      zh_Hans: 待写入的文本
+      pt_BR: The text to write
+    llm_description: The text to write
+    form: llm
+  - name: s3_uri
+    type: string
+    required: true
+    label:
+      en_US: s3 uri
+      zh_Hans: s3 uri
+      pt_BR: s3 uri
+    human_description:
+      en_US: s3 uri
+      zh_Hans: s3 uri
+      pt_BR: s3 uri
+    llm_description: s3 uri
+    form: llm
+  - name: aws_region
+    type: string
+    required: true
+    label:
+      en_US: region of bucket
+      zh_Hans: bucket 所在的region
+      pt_BR: region of
bucket + human_description: + en_US: region of bucket + zh_Hans: bucket 所在的region + pt_BR: region of bucket + llm_description: region of bucket + form: form + - name: operation_type + type: select + required: true + label: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + human_description: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + default: read + options: + - value: read + label: + en_US: read + zh_Hans: 读 + - value: write + label: + en_US: write + zh_Hans: 写 + form: form + - name: generate_presign_url + type: boolean + required: false + label: + en_US: Generate presigned URL + zh_Hans: 生成预签名URL + human_description: + en_US: Whether to generate a presigned URL for the S3 object + zh_Hans: 是否生成S3对象的预签名URL + default: false + form: form + - name: presign_expiry + type: number + required: false + label: + en_US: Presigned URL expiration time + zh_Hans: 预签名URL有效期 + human_description: + en_US: Expiration time in seconds for the presigned URL + zh_Hans: 预签名URL的有效期(秒) + default: 3600 + form: form diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 636debffd4..48aac75dbb 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -210,7 +210,7 @@ class ApiTool(Tool): ) return response else: - raise ValueError(f"Invalid http method {self.method}") + raise ValueError(f"Invalid http method {method}") def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 5052f0897a..2aaca6d82e 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 -from httpx import get +import httpx from configs import dify_config +from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage from models.model import MessageFile @@ -94,12 +95,11 @@ class ToolFileManager: ) -> ToolFile: # try to download image try: - response = get(file_url) + response = ssrf_proxy.get(file_url) response.raise_for_status() blob = response.content - except Exception as e: - logger.exception(f"Failed to download file from {file_url}") - raise + except httpx.TimeoutException as e: + raise ValueError(f"timeout when downloading file from {file_url}") mimetype = guess_type(file_url)[0] or "octet/stream" extension = guess_extension(mimetype) or ".bin" diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py index ec134e031c..aeecf40640 100644 --- a/api/core/workflow/nodes/base/exc.py +++ b/api/core/workflow/nodes/base/exc.py @@ -1,4 +1,4 @@ -class BaseNodeError(Exception): +class BaseNodeError(ValueError): """Base class for node errors.""" pass diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 19b9078a5c..4e371ca436 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Optional from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]): ) # Transform result - result = self._transform_result(result, 
self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - def _check_string(self, value: str, variable: str) -> str: + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, str): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a string") + raise OutputValidationError(f"Output variable `{variable}` must be a string") if len(value) > dify_config.CODE_MAX_STRING_LENGTH: raise OutputValidationError( @@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]): return value.replace("\x00", "") - def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: + def _check_number(self, value: int | float | None, variable: str) -> int | float | None: """ Check number :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, int | float): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: raise OutputValidationError( @@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]): return value def _transform_result( - self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 - ) -> dict: - """ - Transform result - :param result: result - :param output_schema: output schema - :return: - """ + self, + result: Mapping[str, Any], + output_schema: Optional[dict[str, CodeNodeData.Output]], + prefix: str = "", + depth: int = 1, + ): if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 95d0ea3aab..6d82dbe6d7 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -1,6 +1,7 @@ import csv import io import json +import logging import os import tempfile @@ -22,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError +logger = logging.getLogger(__name__) + class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): """ @@ -177,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str: def _extract_text_from_doc(file_content: bytes) -> str: + """ + Extract text from a DOC/DOCX file. 
+ For now support only paragraph and table add more if needed + """ try: doc_file = io.BytesIO(file_content) doc = docx.Document(doc_file) - return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + text = [] + # Process paragraphs + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + text.append(paragraph.text) + + # Process tables + for table in doc.tables: + # Table header + try: + # table maybe cause errors so ignore it. + if len(table.rows) > 0 and table.rows[0].cells is not None: + # Check if any cell in the table has text + has_content = False + for row in table.rows: + if any(cell.text.strip() for cell in row.cells): + has_content = True + break + + if has_content: + markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" + for row in table.rows[1:]: + markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" + text.append(markdown_table) + except Exception as e: + logger.warning(f"Failed to extract table from DOC/DOCX: {e}") + continue + + return "\n".join(text) except Exception as e: raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 5b960ea615..c8c854a43b 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode): error=str(e), metadata={}, ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, + error=str(e), + metadata={}, + ) error = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 15cfabe478..5043e25e2b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -154,8 +154,7 @@ class QuestionClassifierNode(LLMNode): }, llm_usage=usage, ) - - except ValueError as e: + except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/core/workflow/nodes/variable_assigner/common/exc.py index a1178fb020..f8dbedc290 100644 --- a/api/core/workflow/nodes/variable_assigner/common/exc.py +++ b/api/core/workflow/nodes/variable_assigner/common/exc.py @@ -1,4 +1,4 @@ -class VariableOperatorNodeError(Exception): +class VariableOperatorNodeError(ValueError): """Base error type, don't use directly.""" pass diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 01d95dcfb3..13034f5cf5 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -116,8 +116,11 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: + upload_file_id = mapping.get("upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") stmt = select(UploadFile).where( - UploadFile.id == mapping.get("upload_file_id"), + UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) diff 
--git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41c5d20c4b..267af611f5 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict: extracted_content = json_string[start_index:end_index].strip() parsed = json.loads(extracted_content) else: - raise Exception("Could not find JSON block in the output.") + raise ValueError("could not find json block in the output.") return parsed @@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserError(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"got invalid json object. error: {e}") for key in expected_keys: if key not in json_obj: raise OutputParserError( - f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" ) return json_obj diff --git a/api/models/account.py b/api/models/account.py index ce17b90def..932ba1da57 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,6 +2,7 @@ import enum import json from flask_login import UserMixin +from sqlalchemy import func from .engine import db from .types import StringUUID @@ -30,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def is_password_set(self): @@ -187,8 +188,8 @@ class Tenant(db.Model): plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -228,8 +229,8 @@ class TenantAccountJoin(db.Model): current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = 
db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AccountIntegrate(db.Model): @@ -245,8 +246,8 @@ class AccountIntegrate(db.Model): provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class InvitationCode(db.Model): @@ -265,4 +266,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 4d4182cabd..fbffe7a3b2 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,5 +1,7 @@ import enum +from sqlalchemy import func + from .engine import db from .types import StringUUID @@ -23,4 +25,4 @@ class APIBasedExtension(db.Model): name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 97e4d6c0ef..7279e8d5b3 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -50,9 +50,9 @@ class Dataset(db.Model): indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model): mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -264,7 +264,7 @@ class Document(db.Model): 
created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -303,7 +303,7 @@ class Document(db.Model): archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) @@ -527,9 +527,9 @@ class DocumentSegment(db.Model): disabled_by = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -697,7 +697,7 @@ class Embedding(db.Model): ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): @@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model): model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(db.Model): @@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) account = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(db.Model): @@ -751,7 +751,7 @@ class Whitelist(db.Model): id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id 
= db.Column(StringUUID, nullable=True) category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(db.Model): @@ -768,7 +768,7 @@ class DatasetPermission(db.Model): account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(db.Model): @@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model): tenant_id = db.Column(StringUUID, nullable=False) settings = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model): dataset_id = db.Column(StringUUID, nullable=False) external_knowledge_id = db.Column(db.Text, nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index f484acde78..ebf0c16c56 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -30,7 +30,7 @@ class DifySetup(db.Model): __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -85,9 +85,9 @@ class App(db.Model): tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -226,9 +226,9 @@ class AppModelConfig(db.Model): model_id = db.Column(db.String(255), nullable=True) configs = 
db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -482,8 +482,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -507,7 +507,7 @@ class InstalledApp(db.Model): position = db.Column(db.Integer, nullable=False, default=0) is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -548,8 +548,8 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( @@ -700,8 +700,10 @@ class Conversation(db.Model): def status_count(self): messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() status_counts = { + WorkflowRunStatus.RUNNING: 0, WorkflowRunStatus.SUCCEEDED: 0, WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.STOPPED: 0, WorkflowRunStatus.PARTIAL_SUCCESSED: 0, } @@ -789,8 +791,8 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, 
nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1115,8 +1117,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1162,9 +1164,7 @@ class MessageFile(db.Model): upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(db.Model): @@ -1184,8 +1184,8 @@ class MessageAnnotation(db.Model): content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1214,7 +1214,7 @@ class AppAnnotationHitHistory(db.Model): source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) @@ -1248,9 +1248,9 @@ class AppAnnotationSetting(db.Model): score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_account(self): @@ -1296,9 +1296,9 @@ class OperationLog(db.Model): account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, 
nullable=False, server_default=func.current_timestamp()) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(UserMixin, db.Model): @@ -1317,8 +1317,8 @@ class EndUser(UserMixin, db.Model): name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Site(db.Model): @@ -1349,9 +1349,9 @@ class Site(db.Model): prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) code = db.Column(db.String(255)) @property @@ -1393,7 +1393,7 @@ class ApiToken(db.Model): type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1424,9 +1424,7 @@ class UploadFile(db.Model): db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) @@ -1483,7 +1481,7 @@ class ApiRequest(db.Model): request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(db.Model): @@ -1655,7 +1653,7 @@ class Tag(db.Model): type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, 
server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(db.Model): @@ -1671,7 +1669,7 @@ class TagBinding(db.Model): tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(db.Model): @@ -1685,8 +1683,10 @@ class TraceAppConfig(db.Model): app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property diff --git a/api/models/provider.py b/api/models/provider.py index 65f70b76e9..fdd3e802d7 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,7 @@ from enum import Enum +from sqlalchemy import func + from .engine import db from .types import StringUUID @@ -60,8 +62,8 @@ class Provider(db.Model): quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -108,8 +110,8 @@ class ProviderModel(db.Model): model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(db.Model): @@ -124,8 +126,8 @@ class TenantDefaultModel(db.Model): provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(db.Model): @@ -139,8 +141,8 @@ class 
TenantPreferredModelProvider(db.Model): tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(db.Model): @@ -164,8 +166,8 @@ class ProviderOrder(db.Model): paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(db.Model): @@ -186,8 +188,8 @@ class ProviderModelSetting(db.Model): model_type = db.Column(db.String(40), nullable=False) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(db.Model): @@ -209,5 +211,5 @@ class LoadBalancingModelConfig(db.Model): name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 4d98572ef8..114db8e110 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,5 +1,6 @@ import json +from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from .engine import db @@ -19,8 +20,8 @@ class DataSourceOauthBinding(db.Model): access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, 
server_default=db.text("false")) @@ -37,8 +38,8 @@ class DataSourceApiKeyAuthBinding(db.Model): category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): diff --git a/api/models/tools.py b/api/models/tools.py index c390be4625..e90ab669c6 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -2,7 +2,7 @@ import json from typing import Optional import sqlalchemy as sa -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject @@ -36,8 +36,8 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def credentials(self) -> dict: @@ -74,8 +74,8 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def description_i18n(self) -> I18nObject: @@ -120,8 +120,8 @@ class ApiToolProvider(db.Model): # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -198,8 +198,8 @@ class WorkflowToolProvider(db.Model): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def 
user(self) -> Account | None: @@ -251,8 +251,8 @@ class ToolModelInvoke(db.Model): provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ToolConversationVariables(db.Model): @@ -278,8 +278,8 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> dict: diff --git a/api/models/web.py b/api/models/web.py index a0f87cf456..028a768519 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,3 +1,6 @@ +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column + from .engine import db from .model import Message from .types import StringUUID @@ -15,7 +18,7 @@ class SavedMessage(db.Model): message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -31,7 +34,7 @@ class PinnedConversation(db.Model): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 51a6fbc8c8..d5be949bf4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -103,12 +103,13 @@ class Workflow(db.Model): graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=UTC), 
server_onupdate=func.current_timestamp() + db.DateTime, + nullable=False, + default=datetime.now(UTC).replace(tzinfo=None), + server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" @@ -406,7 +407,7 @@ class WorkflowRun(db.Model): total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @@ -636,7 +637,7 @@ class WorkflowNodeExecution(db.Model): error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -754,7 +755,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -780,7 +781,7 @@ class ConversationVariable(db.Model): conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=func.current_timestamp()) updated_at = db.Column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8180c3b400..0478903fa4 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -340,7 +340,10 @@ class AppDslService: ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) - app_mode = AppMode(app_data["mode"]) + app_mode = app_data.get("mode") + if not app_mode: + raise ValueError("missing app mode") + app_mode = AppMode(app_mode) # Set icon type icon_type_value = icon_type or app_data.get("icon_type") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 8642972710..456dc3ebeb 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import UTC, datetime from typing import Optional, Union -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, func, or_, select +from sqlalchemy.orm import Session from
core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -18,19 +19,21 @@ class ConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, + include_ids: Optional[Sequence[str]] = None, + exclude_ids: Optional[Sequence[str]] = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Conversation).filter( + stmt = select(Conversation).where( Conversation.is_deleted == False, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -38,37 +41,40 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - if include_ids is not None: - base_query = base_query.filter(Conversation.id.in_(include_ids)) - + stmt = stmt.where(Conversation.id.in_(include_ids)) if exclude_ids is not None: - base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) if last_id: - last_conversation = base_query.filter(Conversation.id == last_id).first() + last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) if not last_conversation: raise LastConversationNotExistsError() # build filters based on sorting - filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) - base_query = base_query.filter(filter_condition) - - base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) - - conversations = base_query.limit(limit).all() + filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=last_conversation, + ) + stmt = stmt.where(filter_condition) + query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) + conversations = session.scalars(query_stmt).all() has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] rest_filter_condition = cls._build_filter_condition( - sort_field, sort_direction, current_page_last_conversation, is_next_page=True + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=current_page_last_conversation, ) - rest_count = base_query.filter(rest_filter_condition).count() - + count_stmt = stmt.where(rest_filter_condition) + count_stmt = select(func.count()).select_from(count_stmt.subquery()) + rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True @@ -81,11 +87,9 @@ class ConversationService: return sort_by, asc @classmethod - def _build_filter_condition( - cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False - ): + def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + if sort_direction == 
desc: return getattr(Conversation, sort_field) < field_value else: return getattr(Conversation, sort_field) > field_value diff --git a/api/services/message_service.py b/api/services/message_service.py index f432a77c80..be2922f4c5 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -151,7 +151,12 @@ class MessageService: @classmethod def create_feedback( - cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + cls, + app_model: App, + message_id: str, + user: Optional[Union[Account, EndUser]], + rating: Optional[str], + content: Optional[str], ) -> MessageFeedback: if not user: raise ValueError("user cannot be None") @@ -164,6 +169,7 @@ class MessageService: db.session.delete(feedback) elif rating and feedback: feedback.rating = rating + feedback.content = content elif not rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -172,6 +178,7 @@ class MessageService: conversation_id=message.conversation_id, message_id=message.id, rating=rating, + content=content, from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index e2e49d017e..fada881fde 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,9 @@ import json import logging from pathlib import Path +from sqlalchemy import select +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -32,7 +35,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, provider_controller=provider_controller ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = ( + builtin_provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -71,19 +74,18 @@ class BuiltinToolManageService: return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): + def update_builtin_tool_provider( + session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): """ update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() + stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, ) + provider = session.scalar(stmt) try: # get provider @@ -115,13 +117,10 @@ class BuiltinToolManageService: encrypted_credentials=json.dumps(credentials), ) - db.session.add(provider) - db.session.commit() + session.add(provider) else: provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() # delete cache tool_configuration.delete_tool_credentials_cache() @@ -129,15 +128,15 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): + def 
get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ get builtin tool provider credentials """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, + BuiltinToolProvider.provider == provider_name, ) .first() ) @@ -156,7 +155,7 @@ class BuiltinToolManageService: """ delete tool provider """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index d7ccc964cb..508fe20970 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,5 +1,8 @@ from typing import Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -13,6 +16,8 @@ class WebConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], @@ -23,24 +28,25 @@ class WebConversationService: ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None - if pinned is not None: - pinned_conversations = ( - db.session.query(PinnedConversation) - .filter( + if pinned is not None and user: + stmt = ( + select(PinnedConversation.conversation_id) + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) .order_by(PinnedConversation.created_at.desc()) - .all() ) - pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] + pinned_conversation_ids = session.scalars(stmt).all() + if pinned: include_ids = pinned_conversation_ids else: exclude_ids = pinned_conversation_ids return ConversationService.pagination_by_last_id( + session=session, app_model=app_model, user=user, last_id=last_id, diff --git a/docker/.env.example b/docker/.env.example index e8ec246ae2..43e67a8db4 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -923,6 +923,3 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false # Maximum number of submitted thread count in a ThreadPool for parallel node execution MAX_SUBMIT_COUNT=100 -# Proxy -HTTP_PROXY= -HTTPS_PROXY= diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d0738d9305..99bc14c717 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -386,8 +386,6 @@ x-shared-env: &shared-api-worker-env CSP_WHITELIST: ${CSP_WHITELIST:-} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100} - HTTP_PROXY: ${HTTP_PROXY:-} - HTTPS_PROXY: ${HTTPS_PROXY:-} services: # API service diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 4a09e27d98..bb9abdb6fc 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -64,6 +64,12 @@ const WorkflowProcessItem = ({ setShowMessageLogModal(true) }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + const showRetryDetail = useCallback(() => { + 
setCurrentLogItem(item) + setCurrentLogModalActiveTab('TRACING') + setShowMessageLogModal(true) + }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + return (
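
The `content` argument threaded through `MessageService.create_feedback` above is exposed on the public message feedback endpoint documented later in this diff. A minimal client-side sketch of the extended call, assuming a standard `fetch` runtime; the helper name, base URL, and API key handling are illustrative, while the endpoint path and body fields follow the updated cURL samples in the `.mdx` templates below:

```ts
type FeedbackRating = 'like' | 'dislike' | null

// Posts a rating plus the new optional free-form `content` field to the
// message feedback endpoint. Values shown in the usage comment are hypothetical.
async function sendMessageFeedback(
  apiBaseUrl: string,
  apiKey: string,
  messageId: string,
  rating: FeedbackRating,
  content: string,
  user: string,
): Promise<void> {
  const res = await fetch(`${apiBaseUrl}/messages/${messageId}/feedbacks`, {
    method: 'POST',
    headers: {
      'Authorization': `Bearer ${apiKey}`,
      'Content-Type': 'application/json',
    },
    // `content` is the feedback text added alongside `rating` in this diff.
    body: JSON.stringify({ rating, user, content }),
  })
  if (!res.ok)
    throw new Error(`Feedback request failed: ${res.status}`)
}

// Usage (hypothetical values):
// await sendMessageFeedback('https://api.example.com/v1', 'app-xxx', 'msg-123', 'like', 'message feedback information', 'abc-123')
```
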
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index bf8efdb65a..044fc27858 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -28,6 +28,7 @@ export type InputProps = { destructive?: boolean wrapperClassName?: string styleCss?: CSSProperties + unit?: string } & React.InputHTMLAttributes & VariantProps const Input = ({ @@ -43,6 +44,7 @@ const Input = ({ value, placeholder, onChange, + unit, ...props }: InputProps) => { const { t } = useTranslation() @@ -80,6 +82,13 @@ const Input = ({ {destructive && ( )} + { + unit && ( +
+ {unit} +
+ ) + }
) } diff --git a/web/app/components/base/search-input/index.tsx b/web/app/components/base/search-input/index.tsx index 89345fbe32..556a7bdf49 100644 --- a/web/app/components/base/search-input/index.tsx +++ b/web/app/components/base/search-input/index.tsx @@ -23,6 +23,7 @@ const SearchInput: FC = ({ const { t } = useTranslation() const [focus, setFocus] = useState(false) const isComposing = useRef(false) + const [internalValue, setInternalValue] = useState(value) return (
= ({ white && '!bg-white hover:!bg-white group-hover:!bg-white placeholder:!text-gray-400', )} placeholder={placeholder || t('common.operation.search')!} - value={value} + value={internalValue} onChange={(e) => { + setInternalValue(e.target.value) if (!isComposing.current) onChange(e.target.value) }} onCompositionStart={() => { isComposing.current = true }} - onCompositionEnd={() => { + onCompositionEnd={(e) => { isComposing.current = false + onChange(e.data) }} onFocus={() => setFocus(true)} onBlur={() => setFocus(false)} @@ -63,7 +66,10 @@ const SearchInput: FC = ({ {value && (
onChange('')} + onClick={() => { + onChange('') + setInternalValue('') + }} >
diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx index f469076bf3..877955039c 100755 --- a/web/app/components/develop/template/template.en.mdx +++ b/web/app/components/develop/template/template.en.mdx @@ -346,6 +346,9 @@ The text generation application offers non-session support and is ideal for tran User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. + ### Response @@ -353,7 +356,7 @@ The text generation application offers non-session support and is ideal for tran - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -361,7 +364,8 @@ The text generation application offers non-session support and is ideal for tran --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template.ja.mdx b/web/app/components/develop/template/template.ja.mdx index bd92bd7f36..c3b376f2e7 100755 --- a/web/app/components/develop/template/template.ja.mdx +++ b/web/app/components/develop/template/template.ja.mdx @@ -345,6 +345,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from 開発者のルールで定義されたユーザー識別子。アプリケーション内で一意である必要があります。 + + メッセージのフィードバックです。 + ### レスポンス @@ -352,7 +355,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -360,7 +363,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index 7b1bec3546..be7470480f 100755 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -320,6 +320,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -327,7 +330,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -335,7 +338,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index 5f00977e75..5106b6f36a 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -444,6 +444,9 @@ Chat applications support session persistence, allowing previous chat history to User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. 
+ ### Response @@ -451,7 +454,7 @@ Chat applications support session persistence, allowing previous chat history to - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -459,7 +462,8 @@ Chat applications support session persistence, allowing previous chat history to --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_advanced_chat.ja.mdx b/web/app/components/develop/template/template_advanced_chat.ja.mdx index 7c933598f9..cf65a29b44 100644 --- a/web/app/components/develop/template/template_advanced_chat.ja.mdx +++ b/web/app/components/develop/template/template_advanced_chat.ja.mdx @@ -444,6 +444,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。 + + メッセージのフィードバックです。 + ### 応答 @@ -451,7 +454,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -459,7 +462,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index fec0636d40..662309525b 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -450,6 +450,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -457,7 +460,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -465,7 +468,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 1eb289b3c1..d38e80407a 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -408,6 +408,9 @@ Chat applications support session persistence, allowing previous chat history to User identifier, defined by the developer's rules, must be unique within the application. + + The specific content of message feedback. 
+ ### Response @@ -415,7 +418,7 @@ Chat applications support session persistence, allowing previous chat history to - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -423,7 +426,8 @@ Chat applications support session persistence, allowing previous chat history to --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.ja.mdx b/web/app/components/develop/template/template_chat.ja.mdx index fb686e0cff..96db9912d5 100644 --- a/web/app/components/develop/template/template_chat.ja.mdx +++ b/web/app/components/develop/template/template_chat.ja.mdx @@ -408,6 +408,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。 + + メッセージのフィードバックです。 + ### 応答 @@ -415,7 +418,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -423,7 +426,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index af96cab5ff..3d6e3630be 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -423,6 +423,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + 消息反馈的具体信息。 + ### Response @@ -430,7 +433,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \ @@ -438,7 +441,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "rating": "like", - "user": "abc-123" + "user": "abc-123", + "content": "message feedback information" }' ``` diff --git a/web/app/components/workflow/constants.ts b/web/app/components/workflow/constants.ts index ffa14b347b..d04163b853 100644 --- a/web/app/components/workflow/constants.ts +++ b/web/app/components/workflow/constants.ts @@ -506,3 +506,5 @@ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE' export const CUSTOM_NODE = 'custom' export const CUSTOM_EDGE = 'custom' export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK' +export const DEFAULT_RETRY_MAX = 3 +export const DEFAULT_RETRY_INTERVAL = 100 diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index a01b2d3154..822aa490db 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -28,6 +28,7 @@ import { getFilesInLogs, } from '@/app/components/base/file-uploader/utils' import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types' +import type { NodeTracing } from '@/types/workflow' export const useWorkflowRun = () => { const store = useStoreApi() @@ -114,6 +115,7 @@ export const useWorkflowRun = () => { 
onIterationStart, onIterationNext, onIterationFinish, + onNodeRetry, onError, ...restCallback } = callback || {} @@ -440,10 +442,13 @@ export const useWorkflowRun = () => { }) if (currentIndex > -1 && draft.tracing) { draft.tracing[currentIndex] = { + ...data, ...(draft.tracing[currentIndex].extras ? { extras: draft.tracing[currentIndex].extras } : {}), - ...data, + ...(draft.tracing[currentIndex].retryDetail + ? { retryDetail: draft.tracing[currentIndex].retryDetail } + : {}), } as any } })) @@ -616,6 +621,41 @@ export const useWorkflowRun = () => { if (onIterationFinish) onIterationFinish(params) }, + onNodeRetry: (params) => { + const { data } = params + const { + workflowRunningData, + setWorkflowRunningData, + } = workflowStore.getState() + const { + getNodes, + setNodes, + } = store.getState() + + const nodes = getNodes() + setWorkflowRunningData(produce(workflowRunningData!, (draft) => { + const tracing = draft.tracing! + const currentRetryNodeIndex = tracing.findIndex(trace => trace.node_id === data.node_id) + + if (currentRetryNodeIndex > -1) { + const currentRetryNode = tracing[currentRetryNodeIndex] + if (currentRetryNode.retryDetail) + draft.tracing![currentRetryNodeIndex].retryDetail!.push(data as NodeTracing) + + else + draft.tracing![currentRetryNodeIndex].retryDetail = [data as NodeTracing] + } + })) + const newNodes = produce(nodes, (draft) => { + const currentNode = draft.find(node => node.id === data.node_id)! + + currentNode.data._retryIndex = data.retry_index + }) + setNodes(newNodes) + + if (onNodeRetry) + onNodeRetry(params) + }, onParallelBranchStarted: (params) => { // console.log(params, 'parallel start') }, diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx index 92a4deb513..d7e2a953da 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx @@ -17,17 +17,25 @@ import ResultPanel from '@/app/components/workflow/run/result-panel' import Toast from '@/app/components/base/toast' import { TransferMethod } from '@/types/app' import { getProcessedFiles } from '@/app/components/base/file-uploader/utils' +import type { NodeTracing } from '@/types/workflow' +import RetryResultPanel from '@/app/components/workflow/run/retry-result-panel' +import type { BlockEnum } from '@/app/components/workflow/types' +import type { Emoji } from '@/app/components/tools/types' const i18nPrefix = 'workflow.singleRun' type BeforeRunFormProps = { nodeName: string + nodeType?: BlockEnum + toolIcon?: string | Emoji onHide: () => void onRun: (submitData: Record) => void onStop: () => void runningStatus: NodeRunningStatus result?: JSX.Element forms: FormProps[] + retryDetails?: NodeTracing[] + onRetryDetailBack?: any } function formatValue(value: string | any, type: InputVarType) { @@ -50,12 +58,16 @@ function formatValue(value: string | any, type: InputVarType) { } const BeforeRunForm: FC = ({ nodeName, + nodeType, + toolIcon, onHide, onRun, onStop, runningStatus, result, forms, + retryDetails, + onRetryDetailBack = () => { }, }) => { const { t } = useTranslation() @@ -122,48 +134,69 @@ const BeforeRunForm: FC = ({
{t(`${i18nPrefix}.testRun`)} {nodeName}
-
+
{ + onHide() + }}>
- -
-
- {forms.map((form, index) => ( -
-
- {index < forms.length - 1 && } + { + retryDetails?.length && ( +
+ ({ + ...item, + title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`, + node_type: nodeType!, + extras: { + icon: toolIcon!, + }, + }))} + onBack={onRetryDetailBack} + /> +
+ ) + } + { + !retryDetails?.length && ( +
+
+ {forms.map((form, index) => ( +
+ + {index < forms.length - 1 && } +
+ ))}
- ))} -
- -
- {isRunning && ( -
- +
+ {isRunning && ( +
+ +
+ )} +
- )} - -
- {isRunning && ( - - )} - {isFinished && ( - <> - {result} - - )} -
+ {isRunning && ( + + )} + {isFinished && ( + <> + {result} + + )} +
+ ) + }
) diff --git a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx index f11f8bd5fb..89412cabb3 100644 --- a/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx +++ b/web/app/components/workflow/nodes/_base/components/error-handle/error-handle-on-panel.tsx @@ -14,7 +14,6 @@ import type { CommonNodeType, Node, } from '@/app/components/workflow/types' -import Split from '@/app/components/workflow/nodes/_base/components/split' import Tooltip from '@/app/components/base/tooltip' type ErrorHandleProps = Pick @@ -45,7 +44,6 @@ const ErrorHandle = ({ return ( <> -
{ + const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate() + + const handleRetryConfigChange = useCallback((value?: WorkflowRetryConfig) => { + handleNodeDataUpdateWithSyncDraft({ + id, + data: { + retry_config: value, + }, + }) + }, [id, handleNodeDataUpdateWithSyncDraft]) + + return { + handleRetryConfigChange, + } +} + +export const useRetryDetailShowInSingleRun = () => { + const [retryDetails, setRetryDetails] = useState() + + const handleRetryDetailsChange = useCallback((details: NodeTracing[] | undefined) => { + setRetryDetails(details) + }, []) + + return { + retryDetails, + handleRetryDetailsChange, + } +} diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx new file mode 100644 index 0000000000..f5d2f08ac8 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-node.tsx @@ -0,0 +1,88 @@ +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import { + RiAlertFill, + RiCheckboxCircleFill, + RiLoader2Line, +} from '@remixicon/react' +import type { Node } from '@/app/components/workflow/types' +import { NodeRunningStatus } from '@/app/components/workflow/types' +import cn from '@/utils/classnames' + +type RetryOnNodeProps = Pick +const RetryOnNode = ({ + data, +}: RetryOnNodeProps) => { + const { t } = useTranslation() + const { retry_config } = data + const showSelectedBorder = data.selected || data._isBundled || data._isEntering + const { + isRunning, + isSuccessful, + isException, + isFailed, + } = useMemo(() => { + return { + isRunning: data._runningStatus === NodeRunningStatus.Running && !showSelectedBorder, + isSuccessful: data._runningStatus === NodeRunningStatus.Succeeded && !showSelectedBorder, + isFailed: data._runningStatus === NodeRunningStatus.Failed && !showSelectedBorder, + isException: data._runningStatus === NodeRunningStatus.Exception && !showSelectedBorder, + } + }, [data._runningStatus, showSelectedBorder]) + const showDefault = !isRunning && !isSuccessful && !isException && !isFailed + + if (!retry_config) + return null + + return ( +
+
+
+ { + showDefault && ( + t('workflow.nodes.common.retry.retryTimes', { times: retry_config.max_retries }) + ) + } + { + isRunning && ( + <> + + {t('workflow.nodes.common.retry.retrying')} + + ) + } + { + isSuccessful && ( + <> + + {t('workflow.nodes.common.retry.retrySuccessful')} + + ) + } + { + (isFailed || isException) && ( + <> + + {t('workflow.nodes.common.retry.retryFailed')} + + ) + } +
+ { + !showDefault && ( +
+ {data._retryIndex}/{data.retry_config?.max_retries} +
+ ) + } +
+
+ ) +} + +export default RetryOnNode diff --git a/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx b/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx new file mode 100644 index 0000000000..dc877a632c --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/retry-on-panel.tsx @@ -0,0 +1,117 @@ +import { useTranslation } from 'react-i18next' +import { useRetryConfig } from './hooks' +import s from './style.module.css' +import Switch from '@/app/components/base/switch' +import Slider from '@/app/components/base/slider' +import Input from '@/app/components/base/input' +import type { + Node, +} from '@/app/components/workflow/types' +import Split from '@/app/components/workflow/nodes/_base/components/split' + +type RetryOnPanelProps = Pick +const RetryOnPanel = ({ + id, + data, +}: RetryOnPanelProps) => { + const { t } = useTranslation() + const { handleRetryConfigChange } = useRetryConfig(id) + const { retry_config } = data + + const handleRetryEnabledChange = (value: boolean) => { + handleRetryConfigChange({ + retry_enabled: value, + max_retries: retry_config?.max_retries || 3, + retry_interval: retry_config?.retry_interval || 1000, + }) + } + + const handleMaxRetriesChange = (value: number) => { + if (value > 10) + value = 10 + else if (value < 1) + value = 1 + handleRetryConfigChange({ + retry_enabled: true, + max_retries: value, + retry_interval: retry_config?.retry_interval || 1000, + }) + } + + const handleRetryIntervalChange = (value: number) => { + if (value > 5000) + value = 5000 + else if (value < 100) + value = 100 + handleRetryConfigChange({ + retry_enabled: true, + max_retries: retry_config?.max_retries || 3, + retry_interval: value, + }) + } + + return ( + <> +
+
+
+
{t('workflow.nodes.common.retry.retryOnFailure')}
+
+ handleRetryEnabledChange(v)} + /> +
+ { + retry_config?.retry_enabled && ( +
+
+
{t('workflow.nodes.common.retry.maxRetries')}
+ + handleMaxRetriesChange(e.target.value as any)} + min={1} + max={10} + unit={t('workflow.nodes.common.retry.times') || ''} + className={s.input} + /> +
+
+
{t('workflow.nodes.common.retry.retryInterval')}
+ + handleRetryIntervalChange(e.target.value as any)} + min={100} + max={5000} + unit={t('workflow.nodes.common.retry.ms') || ''} + className={s.input} + /> +
+
+ ) + } +
+ + + ) +} + +export default RetryOnPanel diff --git a/web/app/components/workflow/nodes/_base/components/retry/style.module.css b/web/app/components/workflow/nodes/_base/components/retry/style.module.css new file mode 100644 index 0000000000..2ce8717af8 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/style.module.css @@ -0,0 +1,5 @@ +.input::-webkit-inner-spin-button, +.input::-webkit-outer-spin-button { + -webkit-appearance: none; + margin: 0; +} \ No newline at end of file diff --git a/web/app/components/workflow/nodes/_base/components/retry/types.ts b/web/app/components/workflow/nodes/_base/components/retry/types.ts new file mode 100644 index 0000000000..bb5f593fd5 --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/retry/types.ts @@ -0,0 +1,5 @@ +export type WorkflowRetryConfig = { + max_retries: number + retry_interval: number + retry_enabled: boolean +} diff --git a/web/app/components/workflow/nodes/_base/components/retry/utils.ts b/web/app/components/workflow/nodes/_base/components/retry/utils.ts new file mode 100644 index 0000000000..e69de29bb2 diff --git a/web/app/components/workflow/nodes/_base/node.tsx b/web/app/components/workflow/nodes/_base/node.tsx index f2da2da35a..4807fa3b2b 100644 --- a/web/app/components/workflow/nodes/_base/node.tsx +++ b/web/app/components/workflow/nodes/_base/node.tsx @@ -25,7 +25,10 @@ import { useNodesReadOnly, useToolIcon, } from '../../hooks' -import { hasErrorHandleNode } from '../../utils' +import { + hasErrorHandleNode, + hasRetryNode, +} from '../../utils' import { useNodeIterationInteractions } from '../iteration/use-interactions' import type { IterationNodeType } from '../iteration/types' import { @@ -35,6 +38,7 @@ import { import NodeResizer from './components/node-resizer' import NodeControl from './components/node-control' import ErrorHandleOnNode from './components/error-handle/error-handle-on-node' +import RetryOnNode from './components/retry/retry-on-node' import AddVariablePopupWithPosition from './components/add-variable-popup-with-position' import cn from '@/utils/classnames' import BlockIcon from '@/app/components/workflow/block-icon' @@ -237,6 +241,14 @@ const BaseNode: FC = ({
) } + { + hasRetryNode(data.type) && ( + + ) + } { hasErrorHandleNode(data.type) && ( = ({
{cloneElement(children, { id, data })}
+ + { + hasRetryNode(data.type) && ( + + ) + } { hasErrorHandleNode(data.type) && ( = { defaultValue: { @@ -24,6 +27,11 @@ const nodeDefault: NodeDefault = { max_read_timeout: 0, max_write_timeout: 0, }, + retry_config: { + retry_enabled: true, + max_retries: 3, + retry_interval: 100, + }, }, getAvailablePrevNodes(isChatMode: boolean) { const nodes = isChatMode diff --git a/web/app/components/workflow/nodes/http/panel.tsx b/web/app/components/workflow/nodes/http/panel.tsx index 5c613aa0f3..91b3a6140d 100644 --- a/web/app/components/workflow/nodes/http/panel.tsx +++ b/web/app/components/workflow/nodes/http/panel.tsx @@ -1,5 +1,5 @@ import type { FC } from 'react' -import React from 'react' +import { memo } from 'react' import { useTranslation } from 'react-i18next' import useConfig from './use-config' import ApiInput from './components/api-input' @@ -18,6 +18,7 @@ import { FileArrow01 } from '@/app/components/base/icons/src/vender/line/files' import type { NodePanelProps } from '@/app/components/workflow/types' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import ResultPanel from '@/app/components/workflow/run/result-panel' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' const i18nPrefix = 'workflow.nodes.http' @@ -60,6 +61,10 @@ const Panel: FC> = ({ hideCurlPanel, handleCurlImport, } = useConfig(id, data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() // To prevent prompt editor in body not update data. if (!isDataReady) return null @@ -181,6 +186,7 @@ const Panel: FC> = ({ {isShowSingleRun && ( > = ({ runningStatus={runningStatus} onRun={handleRun} onStop={handleStop} - result={} + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )} {(isShowCurlPanel && !readOnly) && ( @@ -207,4 +215,4 @@ const Panel: FC> = ({ ) } -export default React.memo(Panel) +export default memo(Panel) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index e9da9acccc..794fcbca4a 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -129,9 +129,6 @@ export const getMultipleRetrievalConfig = ( reranking_enable: ((allInternal && allEconomic) || allExternal) ? 
reranking_enable : true, } - if (!rerankModelIsValid) - result.reranking_model = undefined - const setDefaultWeights = () => { result.weights = { vector_setting: { @@ -198,7 +195,6 @@ export const getMultipleRetrievalConfig = ( setDefaultWeights() } } - if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) { result.reranking_mode = RerankingModeEnum.WeightedScore setDefaultWeights() diff --git a/web/app/components/workflow/nodes/llm/panel.tsx b/web/app/components/workflow/nodes/llm/panel.tsx index 21ef6395b1..60f68d93e2 100644 --- a/web/app/components/workflow/nodes/llm/panel.tsx +++ b/web/app/components/workflow/nodes/llm/panel.tsx @@ -19,6 +19,7 @@ import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/c import ResultPanel from '@/app/components/workflow/run/result-panel' import Tooltip from '@/app/components/base/tooltip' import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' const i18nPrefix = 'workflow.nodes.llm' @@ -69,6 +70,10 @@ const Panel: FC> = ({ runResult, filterJinjia2InputVar, } = useConfig(id, data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() const model = inputs.model @@ -282,12 +287,15 @@ const Panel: FC> = ({ {isShowSingleRun && ( } + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )}
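
The retry settings introduced above are spread across `types.ts` (the `WorkflowRetryConfig` shape), `constants.ts` (`DEFAULT_RETRY_MAX`, `DEFAULT_RETRY_INTERVAL`) and the panel handlers that clamp user input. A consolidated sketch under the same bounds the panel enforces (1 to 10 retries, 100 to 5000 ms interval); the `normalizeRetryConfig` helper itself is an assumption, not part of the diff:

```ts
export type WorkflowRetryConfig = {
  max_retries: number
  retry_interval: number // milliseconds
  retry_enabled: boolean
}

export const DEFAULT_RETRY_MAX = 3
export const DEFAULT_RETRY_INTERVAL = 100

const clamp = (value: number, min: number, max: number) =>
  Math.min(max, Math.max(min, value))

// Mirrors the bounds applied by the retry panel handlers: max_retries is
// kept within 1..10 and retry_interval within 100..5000 ms.
export const normalizeRetryConfig = (config?: Partial<WorkflowRetryConfig>): WorkflowRetryConfig => ({
  retry_enabled: config?.retry_enabled ?? true,
  max_retries: clamp(config?.max_retries ?? DEFAULT_RETRY_MAX, 1, 10),
  retry_interval: clamp(config?.retry_interval ?? DEFAULT_RETRY_INTERVAL, 100, 5000),
})
```
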
diff --git a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx index db1a32e319..bab7c20d5b 100644 --- a/web/app/components/workflow/nodes/tool/components/input-var-list.tsx +++ b/web/app/components/workflow/nodes/tool/components/input-var-list.tsx @@ -162,7 +162,7 @@ const InputVarList: FC = ({ readonly={readOnly} isShowNodeName nodeId={nodeId} - value={varInput?.type === VarKindType.constant ? (varInput?.value || '') : (varInput?.value || [])} + value={varInput?.type === VarKindType.constant ? (varInput?.value ?? '') : (varInput?.value ?? [])} onChange={handleNotMixedTypeChange(variable)} onOpen={handleOpen(index)} defaultVarKindType={varInput?.type || (isNumber ? VarKindType.constant : VarKindType.variable)} diff --git a/web/app/components/workflow/nodes/tool/panel.tsx b/web/app/components/workflow/nodes/tool/panel.tsx index 49e645faa4..d0d4c3a839 100644 --- a/web/app/components/workflow/nodes/tool/panel.tsx +++ b/web/app/components/workflow/nodes/tool/panel.tsx @@ -14,6 +14,8 @@ import Loading from '@/app/components/base/loading' import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form' import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars' import ResultPanel from '@/app/components/workflow/run/result-panel' +import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks' +import { useToolIcon } from '@/app/components/workflow/hooks' const i18nPrefix = 'workflow.nodes.tool' @@ -48,6 +50,11 @@ const Panel: FC> = ({ handleStop, runResult, } = useConfig(id, data) + const toolIcon = useToolIcon(data) + const { + retryDetails, + handleRetryDetailsChange, + } = useRetryDetailShowInSingleRun() if (isLoading) { return
@@ -143,12 +150,16 @@ const Panel: FC> = ({ {isShowSingleRun && ( } + retryDetails={retryDetails} + onRetryDetailBack={handleRetryDetailsChange} + result={} /> )}
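
The `input-var-list.tsx` change above swaps `||` for `??` when falling back on the variable value. The distinction matters for constant inputs, where `0` or an empty string is a legitimate value the user typed; a small illustration with an assumed, simplified value type:

```ts
type VarValue = string | number | string[] | undefined

// With ||, any falsy value (0, '', false) is replaced by the fallback:
const withOr = (value: VarValue) => value || ''

// With ??, only null/undefined trigger the fallback, so 0 and '' survive:
const withNullish = (value: VarValue) => value ?? ''

withOr(0)      // '' — the constant 0 is silently dropped
withNullish(0) // 0  — preserved, which is what the diff intends
```
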
diff --git a/web/app/components/workflow/panel/debug-and-preview/hooks.ts b/web/app/components/workflow/panel/debug-and-preview/hooks.ts index 5d932a1ba2..ebd5e7a99d 100644 --- a/web/app/components/workflow/panel/debug-and-preview/hooks.ts +++ b/web/app/components/workflow/panel/debug-and-preview/hooks.ts @@ -27,6 +27,7 @@ import { getProcessedFilesFromResponse, } from '@/app/components/base/file-uploader/utils' import type { FileEntity } from '@/app/components/base/file-uploader/types' +import type { NodeTracing } from '@/types/workflow' type GetAbortController = (abortController: AbortController) => void type SendCallback = { @@ -381,6 +382,28 @@ export const useChat = ( } })) }, + onNodeRetry: ({ data }) => { + if (data.iteration_id) + return + + const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => { + if (!item.execution_metadata?.parallel_id) + return item.node_id === data.node_id + return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id) + }) + if (responseItem.workflowProcess!.tracing[currentIndex].retryDetail) + responseItem.workflowProcess!.tracing[currentIndex].retryDetail?.push(data as NodeTracing) + else + responseItem.workflowProcess!.tracing[currentIndex].retryDetail = [data as NodeTracing] + + handleUpdateChatList(produce(chatListRef.current, (draft) => { + const currentIndex = draft.findIndex(item => item.id === responseItem.id) + draft[currentIndex] = { + ...draft[currentIndex], + ...responseItem, + } + })) + }, onNodeFinished: ({ data }) => { if (data.iteration_id) return @@ -394,6 +417,9 @@ export const useChat = ( ...(responseItem.workflowProcess!.tracing[currentIndex]?.extras ? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras } : {}), + ...(responseItem.workflowProcess!.tracing[currentIndex]?.retryDetail + ? 
{ retryDetail: responseItem.workflowProcess!.tracing[currentIndex].retryDetail } + : {}), ...data, } as any handleUpdateChatList(produce(chatListRef.current, (draft) => { diff --git a/web/app/components/workflow/panel/workflow-preview.tsx b/web/app/components/workflow/panel/workflow-preview.tsx index 2139ebd338..210a95f1f8 100644 --- a/web/app/components/workflow/panel/workflow-preview.tsx +++ b/web/app/components/workflow/panel/workflow-preview.tsx @@ -25,6 +25,7 @@ import { import { SimpleBtn } from '../../app/text-generate/item' import Toast from '../../base/toast' import IterationResultPanel from '../run/iteration-result-panel' +import RetryResultPanel from '../run/retry-result-panel' import InputsPanel from './inputs-panel' import cn from '@/utils/classnames' import Loading from '@/app/components/base/loading' @@ -53,11 +54,16 @@ const WorkflowPreview = () => { }, [workflowRunningData]) const [iterationRunResult, setIterationRunResult] = useState([]) + const [retryRunResult, setRetryRunResult] = useState([]) const [iterDurationMap, setIterDurationMap] = useState({}) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) + const [isShowRetryDetail, { + setTrue: doShowRetryDetail, + setFalse: doHideRetryDetail, + }] = useBoolean(false) const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterationDurationMap: IterationDurationMap) => { setIterDurationMap(iterationDurationMap) @@ -65,6 +71,11 @@ const WorkflowPreview = () => { doShowIterationDetail() }, [doShowIterationDetail]) + const handleRetryDetail = useCallback((detail: NodeTracing[]) => { + setRetryRunResult(detail) + doShowRetryDetail() + }, [doShowRetryDetail]) + if (isShowIterationDetail) { return (
{
)} - {currentTab === 'TRACING' && ( + {currentTab === 'TRACING' && !isShowRetryDetail && ( )} {currentTab === 'TRACING' && !workflowRunningData?.tracing?.length && ( @@ -213,7 +225,14 @@ const WorkflowPreview = () => { )} - + { + currentTab === 'TRACING' && isShowRetryDetail && ( + + ) + } )} diff --git a/web/app/components/workflow/run/index.tsx b/web/app/components/workflow/run/index.tsx index 2bf705f4ce..520c59bf4c 100644 --- a/web/app/components/workflow/run/index.tsx +++ b/web/app/components/workflow/run/index.tsx @@ -9,6 +9,7 @@ import OutputPanel from './output-panel' import ResultPanel from './result-panel' import TracingPanel from './tracing-panel' import IterationResultPanel from './iteration-result-panel' +import RetryResultPanel from './retry-result-panel' import cn from '@/utils/classnames' import { ToastContext } from '@/app/components/base/toast' import Loading from '@/app/components/base/loading' @@ -107,6 +108,18 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const processNonIterationNode = (item: NodeTracing) => { const { execution_metadata } = item if (!execution_metadata?.iteration_id) { + if (item.status === 'retry') { + const retryNode = result.find(node => node.node_id === item.node_id) + + if (retryNode) { + if (retryNode?.retryDetail) + retryNode.retryDetail.push(item) + else + retryNode.retryDetail = [item] + } + + return + } result.push(item) return } @@ -181,10 +194,15 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe const [iterationRunResult, setIterationRunResult] = useState([]) const [iterDurationMap, setIterDurationMap] = useState({}) + const [retryRunResult, setRetryRunResult] = useState([]) const [isShowIterationDetail, { setTrue: doShowIterationDetail, setFalse: doHideIterationDetail, }] = useBoolean(false) + const [isShowRetryDetail, { + setTrue: doShowRetryDetail, + setFalse: doHideRetryDetail, + }] = useBoolean(false) const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => { setIterationRunResult(detail) @@ -192,6 +210,11 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe setIterDurationMap(iterDurationMap) }, [doShowIterationDetail, setIterationRunResult, setIterDurationMap]) + const handleShowRetryDetail = useCallback((detail: NodeTracing[]) => { + setRetryRunResult(detail) + doShowRetryDetail() + }, [doShowRetryDetail, setRetryRunResult]) + if (isShowIterationDetail) { return (
@@ -261,13 +284,22 @@ const RunPanel: FC = ({ hideResult, activeTab = 'RESULT', runID, getRe exceptionCounts={runDetail.exceptions_count} /> )} - {!loading && currentTab === 'TRACING' && ( + {!loading && currentTab === 'TRACING' && !isShowRetryDetail && ( )} + { + !loading && currentTab === 'TRACING' && isShowRetryDetail && ( + + ) + }
) diff --git a/web/app/components/workflow/run/node.tsx b/web/app/components/workflow/run/node.tsx index d1a02ecfe0..bb07bd1e8c 100644 --- a/web/app/components/workflow/run/node.tsx +++ b/web/app/components/workflow/run/node.tsx @@ -8,6 +8,7 @@ import { RiCheckboxCircleFill, RiErrorWarningLine, RiLoader2Line, + RiRestartFill, } from '@remixicon/react' import BlockIcon from '../block-icon' import { BlockEnum } from '../types' @@ -20,6 +21,7 @@ import Button from '@/app/components/base/button' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import type { IterationDurationMap, NodeTracing } from '@/types/workflow' import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip' +import { hasRetryNode } from '@/app/components/workflow/utils' type Props = { className?: string @@ -28,8 +30,10 @@ type Props = { hideInfo?: boolean hideProcessDetail?: boolean onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void + onShowRetryDetail?: (detail: NodeTracing[]) => void notShowIterationNav?: boolean justShowIterationNavArrow?: boolean + justShowRetryNavArrow?: boolean } const NodePanel: FC = ({ @@ -39,6 +43,7 @@ const NodePanel: FC = ({ hideInfo = false, hideProcessDetail, onShowIterationDetail, + onShowRetryDetail, notShowIterationNav, justShowIterationNavArrow, }) => { @@ -88,11 +93,17 @@ const NodePanel: FC = ({ }, [nodeInfo.expand, setCollapseState]) const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration + const isRetryNode = hasRetryNode(nodeInfo.node_type) && nodeInfo.retryDetail const handleOnShowIterationDetail = (e: React.MouseEvent) => { e.stopPropagation() e.nativeEvent.stopImmediatePropagation() onShowIterationDetail?.(nodeInfo.details || [], nodeInfo?.iterDurationMap || nodeInfo.execution_metadata?.iteration_duration_map || {}) } + const handleOnShowRetryDetail = (e: React.MouseEvent) => { + e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() + onShowRetryDetail?.(nodeInfo.retryDetail || []) + } return (
@@ -169,6 +180,19 @@ const NodePanel: FC = ({
)} + {isRetryNode && ( + + )}
{(nodeInfo.status === 'stopped') && ( diff --git a/web/app/components/workflow/run/result-panel.tsx b/web/app/components/workflow/run/result-panel.tsx index a688693e4f..bbe740ad48 100644 --- a/web/app/components/workflow/run/result-panel.tsx +++ b/web/app/components/workflow/run/result-panel.tsx @@ -1,11 +1,17 @@ 'use client' import type { FC } from 'react' import { useTranslation } from 'react-i18next' +import { + RiArrowRightSLine, + RiRestartFill, +} from '@remixicon/react' import StatusPanel from './status' import MetaData from './meta' import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip' +import type { NodeTracing } from '@/types/workflow' +import Button from '@/app/components/base/button' type ResultPanelProps = { inputs?: string @@ -22,6 +28,8 @@ type ResultPanelProps = { showSteps?: boolean exceptionCounts?: number execution_metadata?: any + retry_events?: NodeTracing[] + onShowRetryDetail?: (retries: NodeTracing[]) => void } const ResultPanel: FC = ({ @@ -38,8 +46,11 @@ const ResultPanel: FC = ({ showSteps, exceptionCounts, execution_metadata, + retry_events, + onShowRetryDetail, }) => { const { t } = useTranslation() + return (
@@ -51,6 +62,23 @@ const ResultPanel: FC = ({ exceptionCounts={exceptionCounts} />
+ { + retry_events?.length && onShowRetryDetail && ( +
+ +
+ ) + }
void +} + +const RetryResultPanel: FC = ({ + list, + onBack, +}) => { + const { t } = useTranslation() + + return ( +
+
{ + e.stopPropagation() + e.nativeEvent.stopImmediatePropagation() + onBack() + }} + > + + {t('workflow.singleRun.back')} +
+ ({ + ...item, + title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`, + }))} + className='bg-background-section-burn' + /> +
+ ) +} +export default memo(RetryResultPanel) diff --git a/web/app/components/workflow/run/tracing-panel.tsx b/web/app/components/workflow/run/tracing-panel.tsx index 57b3a5cf5f..ad78971895 100644 --- a/web/app/components/workflow/run/tracing-panel.tsx +++ b/web/app/components/workflow/run/tracing-panel.tsx @@ -21,6 +21,7 @@ import type { IterationDurationMap, NodeTracing } from '@/types/workflow' type TracingPanelProps = { list: NodeTracing[] onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void + onShowRetryDetail?: (detail: NodeTracing[]) => void className?: string hideNodeInfo?: boolean hideNodeProcessDetail?: boolean @@ -160,6 +161,7 @@ function buildLogTree(nodes: NodeTracing[], t: (key: string) => string): Tracing const TracingPanel: FC = ({ list, onShowIterationDetail, + onShowRetryDetail, className, hideNodeInfo = false, hideNodeProcessDetail = false, @@ -251,7 +253,9 @@ const TracingPanel: FC = ({ diff --git a/web/app/components/workflow/types.ts b/web/app/components/workflow/types.ts index c40ea0de55..6d0fabd90e 100644 --- a/web/app/components/workflow/types.ts +++ b/web/app/components/workflow/types.ts @@ -13,6 +13,7 @@ import type { DefaultValueForm, ErrorHandleTypeEnum, } from '@/app/components/workflow/nodes/_base/components/error-handle/types' +import type { WorkflowRetryConfig } from '@/app/components/workflow/nodes/_base/components/retry/types' export enum BlockEnum { Start = 'start', @@ -68,6 +69,7 @@ export type CommonNodeType = { _iterationIndex?: number _inParallelHovering?: boolean _waitingRun?: boolean + _retryIndex?: number isInIteration?: boolean iteration_id?: string selected?: boolean @@ -77,6 +79,7 @@ export type CommonNodeType = { width?: number height?: number error_strategy?: ErrorHandleTypeEnum + retry_config?: WorkflowRetryConfig default_value?: DefaultValueForm[] } & T & Partial> @@ -293,6 +296,7 @@ export enum NodeRunningStatus { Succeeded = 'succeeded', Failed = 'failed', Exception = 'exception', + Retry = 'retry', } export type OnNodeAdd = ( diff --git a/web/app/components/workflow/utils.ts b/web/app/components/workflow/utils.ts index abe129e6d6..4c61267e4c 100644 --- a/web/app/components/workflow/utils.ts +++ b/web/app/components/workflow/utils.ts @@ -26,6 +26,8 @@ import { } from './types' import { CUSTOM_NODE, + DEFAULT_RETRY_INTERVAL, + DEFAULT_RETRY_MAX, ITERATION_CHILDREN_Z_INDEX, ITERATION_NODE_Z_INDEX, NODE_WIDTH_X_OFFSET, @@ -279,6 +281,14 @@ export const initialNodes = (originNodes: Node[], originEdges: Edge[]) => { iterationNodeData.error_handle_mode = iterationNodeData.error_handle_mode || ErrorHandleMode.Terminated } + if (node.data.type === BlockEnum.HttpRequest && !node.data.retry_config) { + node.data.retry_config = { + retry_enabled: true, + max_retries: DEFAULT_RETRY_MAX, + retry_interval: DEFAULT_RETRY_INTERVAL, + } + } + return node }) } @@ -797,3 +807,7 @@ export const isExceptionVariable = (variable: string, nodeType?: BlockEnum) => { return false } + +export const hasRetryNode = (nodeType?: BlockEnum) => { + return nodeType === BlockEnum.LLM || nodeType === BlockEnum.Tool || nodeType === BlockEnum.HttpRequest || nodeType === BlockEnum.Code +} diff --git a/web/i18n/de-DE/workflow.ts b/web/i18n/de-DE/workflow.ts index 8888e23739..38686f8c1d 100644 --- a/web/i18n/de-DE/workflow.ts +++ b/web/i18n/de-DE/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Fehlerbehandlung', tip: 'Ausnahmebehandlungsstrategie, die ausgelöst wird, wenn ein Knoten auf eine Ausnahme stößt.', }, + retry: 
{ + retry: 'Wiederholen', + retryOnFailure: 'Wiederholen bei Fehler', + maxRetries: 'Max. Wiederholungen', + retryInterval: 'Wiederholungsintervall', + retryTimes: 'Wiederholen Sie {{times}} mal bei einem Fehler', + retrying: 'Wiederholung...', + retrySuccessful: 'Wiederholen erfolgreich', + retryFailed: 'Wiederholung fehlgeschlagen', + retryFailedTimes: '{{times}} fehlgeschlagene Wiederholungen', + times: 'mal', + ms: 'ms', + retries: '{{num}} Wiederholungen', + }, }, start: { required: 'erforderlich', diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index e2a2fdb59d..fab25fa509 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -329,6 +329,20 @@ const translation = { tip: 'There are {{num}} nodes in the process running abnormally, please go to tracing to check the logs.', }, }, + retry: { + retry: 'Retry', + retryOnFailure: 'retry on failure', + maxRetries: 'max retries', + retryInterval: 'retry interval', + retryTimes: 'Retry {{times}} times on failure', + retrying: 'Retrying...', + retrySuccessful: 'Retry successful', + retryFailed: 'Retry failed', + retryFailedTimes: '{{times}} retries failed', + times: 'times', + ms: 'ms', + retries: '{{num}} Retries', + }, }, start: { required: 'required', diff --git a/web/i18n/es-ES/workflow.ts b/web/i18n/es-ES/workflow.ts index c49c611da8..d112ad97b6 100644 --- a/web/i18n/es-ES/workflow.ts +++ b/web/i18n/es-ES/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Manejo de errores', tip: 'Estrategia de control de excepciones, que se desencadena cuando un nodo encuentra una excepción.', }, + retry: { + retryOnFailure: 'Volver a intentarlo en caso de error', + maxRetries: 'Número máximo de reintentos', + retryInterval: 'Intervalo de reintento', + retryTimes: 'Reintentar {{times}} veces en caso de error', + retrying: 'Reintentando...', + retrySuccessful: 'Reintento correcto', + retryFailed: 'Error en el reintento', + retryFailedTimes: '{{times}} reintentos fallidos', + times: 'veces', + ms: 'ms', + retries: '{{num}} Reintentos', + retry: 'Reintentar', + }, }, start: { required: 'requerido', diff --git a/web/i18n/fa-IR/workflow.ts b/web/i18n/fa-IR/workflow.ts index c29f911556..37cba2f16b 100644 --- a/web/i18n/fa-IR/workflow.ts +++ b/web/i18n/fa-IR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'مدیریت خطا', tip: 'استراتژی مدیریت استثنا، زمانی که یک گره با یک استثنا مواجه می شود، فعال می شود.', }, + retry: { + times: 'بار', + retryInterval: 'فاصله تلاش مجدد', + retryOnFailure: 'در صورت شکست دوباره امتحان کنید', + ms: 'میلی‌ثانیه', + retry: 'تلاش مجدد', + retries: '{{num}} تلاش های مجدد', + maxRetries: 'حداکثر تلاش مجدد', + retrying: 'در حال تلاش مجدد...', + retryFailed: 'تلاش مجدد ناموفق بود', + retryTimes: '{{times}} بار در صورت شکست دوباره امتحان کنید', + retrySuccessful: 'تلاش مجدد موفق بود', + retryFailedTimes: '{{times}} تلاش های مجدد ناموفق بود', + }, }, start: { required: 'الزامی', diff --git a/web/i18n/fr-FR/workflow.ts b/web/i18n/fr-FR/workflow.ts index a2b2406113..e7d2802cb4 100644 --- a/web/i18n/fr-FR/workflow.ts +++ b/web/i18n/fr-FR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Gestion des erreurs', tip: 'Stratégie de gestion des exceptions, déclenchée lorsqu’un nœud rencontre une exception.', }, + retry: { + retry: 'Réessayer', + retryOnFailure: 'Réessai en cas d’échec', + maxRetries: 'Nombre maximal de tentatives', + retryInterval: 'intervalle de nouvelle tentative', + retryTimes: 'Réessayez {{times}} fois en cas d’échec', + retrying: 'Nouvelle tentative...',
+ retrySuccessful: 'Réessai réussi', + retryFailed: 'Échec de la nouvelle tentative', + retryFailedTimes: '{{times}} tentatives ont échoué', + times: 'fois', + ms: 'ms', + retries: '{{num}} Tentatives', + }, }, start: { required: 'requis', diff --git a/web/i18n/hi-IN/workflow.ts b/web/i18n/hi-IN/workflow.ts index 47589078ce..619abee128 100644 --- a/web/i18n/hi-IN/workflow.ts +++ b/web/i18n/hi-IN/workflow.ts @@ -334,6 +334,20 @@ const translation = { title: 'त्रुटि हैंडलिंग', tip: 'अपवाद हैंडलिंग रणनीति, ट्रिगर जब एक नोड एक अपवाद का सामना करता है।', }, + retry: { + times: 'बार', + ms: 'मिलीसेकंड', + retryInterval: 'पुनः प्रयास अंतराल', + retrying: 'पुनः प्रयास हो रहा है...', + retryFailed: 'पुनः प्रयास विफल रहा', + retryFailedTimes: '{{times}} पुनः प्रयास विफल रहे', + retryTimes: 'विफलता पर {{times}} बार पुनः प्रयास करें', + retries: '{{num}} पुनर्प्रयास', + maxRetries: 'अधिकतम पुनः प्रयास', + retrySuccessful: 'पुनः प्रयास सफल', + retry: 'पुनर्प्रयास', + retryOnFailure: 'विफलता पर पुनः प्रयास करें', + }, }, start: { required: 'आवश्यक', diff --git a/web/i18n/it-IT/workflow.ts b/web/i18n/it-IT/workflow.ts index e760074e6a..f4390580d5 100644 --- a/web/i18n/it-IT/workflow.ts +++ b/web/i18n/it-IT/workflow.ts @@ -337,6 +337,20 @@ const translation = { title: 'Gestione degli errori', tip: 'Strategia di gestione delle eccezioni, attivata quando un nodo rileva un\'eccezione.', }, + retry: { + retry: 'Riprova', + retryOnFailure: 'Riprova in caso di errore', + maxRetries: 'Numero massimo di tentativi', + retryInterval: 'Intervallo tentativi', + retryTimes: 'Riprova {{times}} volte in caso di errore', + retrying: 'Nuovo tentativo...', + retryFailedTimes: '{{times}} tentativi falliti', + times: 'volte', + retries: '{{num}} Tentativi', + retrySuccessful: 'Nuovo tentativo riuscito', + retryFailed: 'Nuovo tentativo non riuscito', + ms: 'ms', + }, }, start: { required: 'richiesto', diff --git a/web/i18n/ja-JP/workflow.ts b/web/i18n/ja-JP/workflow.ts index 8305105c22..1aa764a19f 100644 --- a/web/i18n/ja-JP/workflow.ts +++ b/web/i18n/ja-JP/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'エラー処理', tip: 'ノードが例外を検出したときにトリガーされる例外処理戦略。', }, + retry: { + retry: 'リトライ', + retryOnFailure: '失敗時の再試行', + maxRetries: '最大再試行回数', + retryInterval: '再試行間隔', + retrying: '再試行中...', + retryFailed: '再試行に失敗しました', + times: '回', + ms: 'ミリ秒', + retryTimes: '失敗時に{{times}}回再試行', + retrySuccessful: '再試行に成功しました', + retries: '{{num}} 回の再試行', + retryFailedTimes: '{{times}}回のリトライが失敗しました', + }, }, start: { required: '必須', diff --git a/web/i18n/ko-KR/workflow.ts b/web/i18n/ko-KR/workflow.ts index cc2c1b1a28..4a4d2f9193 100644 --- a/web/i18n/ko-KR/workflow.ts +++ b/web/i18n/ko-KR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: '오류 처리', tip: '노드에 예외가 발생할 때 트리거되는 예외 처리 전략입니다.', }, + retry: { + retry: '재시도', + retryOnFailure: '실패 시 재시도', + maxRetries: '최대 재시도 횟수', + retryInterval: '재시도 간격', + retryTimes: '실패 시 {{times}}번 재시도', + retrying: '재시도 중...', + retrySuccessful: '재시도 성공', + retryFailed: '재시도 실패', + retryFailedTimes: '{{times}} 재시도 실패', + times: '번', + ms: 'ms', + retries: '{{num}} 재시도', + }, }, start: { required: '필수', diff --git a/web/i18n/pl-PL/workflow.ts b/web/i18n/pl-PL/workflow.ts index 2db6cf2bfb..13784df603 100644 --- a/web/i18n/pl-PL/workflow.ts +++ b/web/i18n/pl-PL/workflow.ts @@ -322,6 +322,20 @@ const translation = { tip: 'Strategia obsługi wyjątków, wyzwalana, gdy węzeł napotka wyjątek.', title: 'Obsługa błędów', }, + retry: { + retry: 'Ponów próbę', + maxRetries: 'Maksymalna liczba ponownych prób',
+ retryInterval: 'Interwał ponawiania prób', + retryTimes: 'Ponów próbę {{times}} razy w przypadku niepowodzenia', + retrying: 'Ponawianie...', + retrySuccessful: 'Ponawianie próby powiodło się', + retryFailed: 'Ponawianie próby nie powiodło się', + times: 'razy', + retries: '{{num}} ponownych prób', + retryOnFailure: 'Ponawianie próby w przypadku niepowodzenia', + retryFailedTimes: '{{times}} ponownych prób nie powiodło się', + ms: 'ms', + }, }, start: { required: 'wymagane', diff --git a/web/i18n/pt-BR/workflow.ts b/web/i18n/pt-BR/workflow.ts index 4d53ec07c7..b99c64cdf4 100644 --- a/web/i18n/pt-BR/workflow.ts +++ b/web/i18n/pt-BR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Tratamento de erros', tip: 'Estratégia de tratamento de exceções, disparada quando um nó encontra uma exceção.', }, + retry: { + retry: 'Repetir', + retryOnFailure: 'Tentar novamente em caso de falha', + maxRetries: 'Máximo de tentativas', + retryInterval: 'Intervalo de repetição', + retryTimes: 'Tente novamente {{times}} vezes em caso de falha', + retrying: 'Repetindo...', + retrySuccessful: 'Repetição bem-sucedida', + retryFailed: 'Falha na nova tentativa', + retryFailedTimes: '{{times}} tentativas falharam', + times: 'vezes', + ms: 'ms', + retries: '{{num}} Tentativas', + }, }, start: { required: 'requerido', diff --git a/web/i18n/ro-RO/workflow.ts b/web/i18n/ro-RO/workflow.ts index 3dfa6d04ed..b142640c9b 100644 --- a/web/i18n/ro-RO/workflow.ts +++ b/web/i18n/ro-RO/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Gestionarea erorilor', tip: 'Strategie de gestionare a excepțiilor, declanșată atunci când un nod întâlnește o excepție.', }, + retry: { + retry: 'Reîncercare', + retryOnFailure: 'Reîncercați în caz de eșec', + maxRetries: 'numărul maxim de încercări', + retryInterval: 'Interval de reîncercare', + retrying: 'Se reîncearcă...', + retrySuccessful: 'Reîncercare reușită', + retryFailed: 'Reîncercarea a eșuat', + retryFailedTimes: '{{times}} reîncercări eșuate', + times: 'ori', + ms: 'ms', + retries: '{{num}} Încercări', + retryTimes: 'Reîncercați de {{times}} ori în caz de eșec', + }, }, start: { required: 'necesar', diff --git a/web/i18n/ru-RU/workflow.ts b/web/i18n/ru-RU/workflow.ts index 600c59f2ed..49c43b4d6d 100644 --- a/web/i18n/ru-RU/workflow.ts +++ b/web/i18n/ru-RU/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Обработка ошибок', tip: 'Стратегия обработки исключений, запускаемая при обнаружении исключения на узле.', }, + retry: { + retry: 'Повторить', + retryOnFailure: 'Повторная попытка при неудаче', + maxRetries: 'максимальное количество повторных попыток', + retryInterval: 'Интервал повторных попыток', + retryTimes: 'Повторите {{times}} раз при неудаче', + retrying: 'Повторная попытка...', + retrySuccessful: 'Повторная попытка выполнена успешно', + retryFailed: 'Повторная попытка не удалась', + times: 'раз', + ms: 'мс', + retryFailedTimes: '{{times}} повторных попыток не увенчались успехом', + retries: '{{num}} повторных попыток', + }, }, start: { required: 'обязательно', diff --git a/web/i18n/sl-SI/workflow.ts b/web/i18n/sl-SI/workflow.ts index 2c9dab8b55..7c40c25e92 100644 --- a/web/i18n/sl-SI/workflow.ts +++ b/web/i18n/sl-SI/workflow.ts @@ -759,6 +759,20 @@ const translation = { title: 'Ravnanje z napakami', tip: 'Strategija ravnanja z izjemami, ki se sproži, ko vozlišče naleti na izjemo.', }, + retry: { + retryOnFailure: 'Ponovni poskus ob neuspehu', + retryInterval: 'Interval ponovnega poskusa', + retrying: 'Ponovno poskušanje...', + retry: 'Ponoviti',
+ retryFailedTimes: '{{times}} ponovni poskusi niso uspeli', + retries: '{{num}} Poskusov', + times: 'krat', + retryTimes: 'Ponovni poskus {{times}}-krat ob neuspehu', + retryFailed: 'Ponovni poskus ni uspel', + retrySuccessful: 'Ponovni poskus je bil uspešen', + maxRetries: 'Največ ponovnih poskusov', + ms: 'ms', + }, }, start: { outputVars: { diff --git a/web/i18n/th-TH/workflow.ts b/web/i18n/th-TH/workflow.ts index c4305466aa..b8d2e72de0 100644 --- a/web/i18n/th-TH/workflow.ts +++ b/web/i18n/th-TH/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'การจัดการข้อผิดพลาด', tip: 'กลยุทธ์การจัดการข้อยกเว้น ทริกเกอร์เมื่อโหนดพบข้อยกเว้น', }, + retry: { + retry: 'ลองใหม่', + retryOnFailure: 'ลองใหม่เมื่อล้มเหลว', + maxRetries: 'การลองซ้ําสูงสุด', + retryInterval: 'ช่วงเวลาลองใหม่', + retryTimes: 'ลอง {{times}} ครั้งเมื่อล้มเหลว', + retrying: 'กําลังลองซ้ํา...', + retrySuccessful: 'ลองใหม่สําเร็จ', + retryFailed: 'ลองใหม่ล้มเหลว', + retryFailedTimes: '{{times}} การลองซ้ําล้มเหลว', + times: 'ครั้ง', + retries: '{{num}} การลองใหม่', + ms: 'มิลลิวินาที', + }, }, start: { required: 'ต้องระบุ', diff --git a/web/i18n/tr-TR/workflow.ts b/web/i18n/tr-TR/workflow.ts index 951a20e049..edec6a0b49 100644 --- a/web/i18n/tr-TR/workflow.ts +++ b/web/i18n/tr-TR/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Hata İşleme', tip: 'Bir düğüm bir özel durumla karşılaştığında tetiklenen özel durum işleme stratejisi.', }, + retry: { + retry: 'Yeniden dene', + retryOnFailure: 'Hata durumunda yeniden dene', + maxRetries: 'En fazla yeniden deneme', + times: 'kere', + retries: '{{num}} yeniden deneme', + retryFailed: 'Yeniden deneme başarısız oldu', + retryInterval: 'Yeniden deneme aralığı', + retryTimes: 'Hata durumunda {{times}} kez yeniden deneyin', + retryFailedTimes: '{{times}} yeniden deneme başarısız oldu', + retrySuccessful: 'Yeniden deneme başarılı', + retrying: 'Yeniden deneniyor...', + ms: 'ms', + }, }, start: { required: 'gerekli', diff --git a/web/i18n/uk-UA/workflow.ts b/web/i18n/uk-UA/workflow.ts index 2c00d3bf59..29fd9d8188 100644 --- a/web/i18n/uk-UA/workflow.ts +++ b/web/i18n/uk-UA/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: 'Обробка помилок', tip: 'Стратегія обробки винятків, що спрацьовує, коли вузол стикається з винятком.', }, + retry: { + retry: 'Повторити', + retryOnFailure: 'повторити спробу в разі невдачі', + retryInterval: 'Інтервал повторних спроб', + retrying: 'Повторна спроба...', + retryFailed: 'Повторна спроба не вдалася', + times: 'разів', + ms: 'мс', + retries: '{{num}} Спроб', + maxRetries: 'Максимальна кількість повторних спроб', + retrySuccessful: 'Повторна спроба успішна', + retryFailedTimes: '{{times}} повторних спроб не вдалися', + retryTimes: 'Повторіть спробу {{times}} разів у разі невдачі', + }, }, start: { required: 'обов\'язковий', diff --git a/web/i18n/vi-VN/workflow.ts b/web/i18n/vi-VN/workflow.ts index 956fe84159..9e16cb5347 100644 --- a/web/i18n/vi-VN/workflow.ts +++ b/web/i18n/vi-VN/workflow.ts @@ -322,6 +322,20 @@ const translation = { tip: 'Chiến lược xử lý ngoại lệ, được kích hoạt khi một nút gặp phải ngoại lệ.', title: 'Xử lý lỗi', }, + retry: { + retry: 'Thử lại', + maxRetries: 'Số lần thử lại tối đa', + retryInterval: 'Khoảng thời gian thử lại', + retryTimes: 'Thử lại {{times}} lần khi không thành công', + retrying: 'Đang thử lại...', + retrySuccessful: 'Thử lại thành công', + retryFailed: 'Thử lại không thành công', + retryFailedTimes: '{{times}} lần thử lại không thành công', + retries: '{{num}} lần thử lại', + retryOnFailure: 'Thử lại khi không thành công',
+ times: 'lần', + ms: 'ms', + }, }, start: { required: 'bắt buộc', diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 19cda33057..dfad9208e7 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -329,6 +329,20 @@ const translation = { tip: '流程中有 {{num}} 个节点运行异常,请前往追踪查看日志。', }, }, + retry: { + retry: '重试', + retryOnFailure: '失败时重试', + maxRetries: '最大重试次数', + retryInterval: '重试间隔', + retryTimes: '失败时重试 {{times}} 次', + retrying: '重试中...', + retrySuccessful: '重试成功', + retryFailed: '重试失败', + retryFailedTimes: '{{times}} 次重试失败', + times: '次', + ms: '毫秒', + retries: '{{num}} 重试次数', + }, }, start: { required: '必填', diff --git a/web/i18n/zh-Hant/workflow.ts b/web/i18n/zh-Hant/workflow.ts index 4bbbf7a04f..a78c6a2f04 100644 --- a/web/i18n/zh-Hant/workflow.ts +++ b/web/i18n/zh-Hant/workflow.ts @@ -322,6 +322,20 @@ const translation = { title: '錯誤處理', tip: '異常處理策略,當節點遇到異常時觸發。', }, + retry: { + retry: '重試', + retryOnFailure: '失敗時重試', + maxRetries: '最大重試次數', + retryInterval: '重試間隔', + retryTimes: '失敗時重試 {{times}} 次', + retrying: '重試中...', + retrySuccessful: '重試成功', + retryFailed: '重試失敗', + retryFailedTimes: '{{times}} 次重試失敗', + times: '次', + ms: '毫秒', + retries: '{{num}} 次重試', + }, }, start: { required: '必填', diff --git a/web/service/base.ts b/web/service/base.ts index 03421d92a4..22b1a43ad1 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -62,6 +62,7 @@ export type IOnNodeStarted = (nodeStarted: NodeStartedResponse) => void export type IOnNodeFinished = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationStarted = (workflowStarted: IterationStartedResponse) => void export type IOnIterationNext = (workflowStarted: IterationNextResponse) => void +export type IOnNodeRetry = (nodeFinished: NodeFinishedResponse) => void export type IOnIterationFinished = (workflowFinished: IterationFinishedResponse) => void export type IOnParallelBranchStarted = (parallelBranchStarted: ParallelBranchStartedResponse) => void export type IOnParallelBranchFinished = (parallelBranchFinished: ParallelBranchFinishedResponse) => void @@ -92,6 +93,7 @@ export type IOtherOptions = { onIterationStart?: IOnIterationStarted onIterationNext?: IOnIterationNext onIterationFinish?: IOnIterationFinished + onNodeRetry?: IOnNodeRetry onParallelBranchStarted?: IOnParallelBranchStarted onParallelBranchFinished?: IOnParallelBranchFinished onTextChunk?: IOnTextChunk @@ -165,6 +167,7 @@ const handleStream = ( onIterationStart?: IOnIterationStarted, onIterationNext?: IOnIterationNext, onIterationFinish?: IOnIterationFinished, + onNodeRetry?: IOnNodeRetry, onParallelBranchStarted?: IOnParallelBranchStarted, onParallelBranchFinished?: IOnParallelBranchFinished, onTextChunk?: IOnTextChunk, @@ -256,6 +259,9 @@ const handleStream = ( else if (bufferObj.event === 'iteration_completed') { onIterationFinish?.(bufferObj as IterationFinishedResponse) } + else if (bufferObj.event === 'node_retry') { onNodeRetry?.(bufferObj as NodeFinishedResponse) } else if (bufferObj.event === 'parallel_branch_started') { onParallelBranchStarted?.(bufferObj as ParallelBranchStartedResponse) } @@ -462,6 +468,7 @@ export const ssePost = ( onIterationStart, onIterationNext, onIterationFinish, + onNodeRetry, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, @@ -533,7 +540,7 @@ export const ssePost = ( return } onData?.(str, isFirstMessage, moreInfo) - }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, 
onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) + }, onCompleted, onThought, onMessageEnd, onMessageReplace, onFile, onWorkflowStarted, onWorkflowFinished, onNodeStarted, onNodeFinished, onIterationStart, onIterationNext, onIterationFinish, onNodeRetry, onParallelBranchStarted, onParallelBranchFinished, onTextChunk, onTTSChunk, onTTSEnd, onTextReplace) }).catch((e) => { if (e.toString() !== 'AbortError: The user aborted a request.' && !e.toString().errorMessage.includes('TypeError: Cannot assign to read only property')) Toast.notify({ type: 'error', message: e }) diff --git a/web/types/workflow.ts b/web/types/workflow.ts index 38f0bb5a40..cd6e9cfa5f 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -52,10 +52,12 @@ export type NodeTracing = { extras?: any expand?: boolean // for UI details?: NodeTracing[][] // iteration detail + retryDetail?: NodeTracing[] // retry detail parallel_id?: string parallel_start_node_id?: string parent_parallel_id?: string parent_parallel_start_node_id?: string + retry_index?: number } export type FetchWorkflowDraftResponse = { @@ -178,6 +180,7 @@ export type NodeFinishedResponse = { } created_at: number files?: FileResponse[] + retry_index?: number } }