diff --git a/api/commands.py b/api/commands.py index 09548ac9f3..bf013cc77e 100644 --- a/api/commands.py +++ b/api/commands.py @@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str if language not in languages: language = "en-US" - name = name.strip() + # Validates name encoding for non-Latin characters. + name = name.strip().encode("utf-8").decode("utf-8") if name else None # generate random password new_password = secrets.token_urlsafe(16) diff --git a/api/controllers/common/errors.py b/api/controllers/common/errors.py index c71f1ce5a3..9f762b3135 100644 --- a/api/controllers/common/errors.py +++ b/api/controllers/common/errors.py @@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException class FilenameNotExistsError(HTTPException): code = 400 description = "The specified filename does not exist." + + +class RemoteFileUploadError(HTTPException): + code = 400 + description = "Error uploading remote file." diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 6f9d7769b9..5e7a3da017 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,12 +1,14 @@ from flask_login import current_user from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=current_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.EXPLORE, - pinned=pinned, - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=current_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.EXPLORE, + pinned=pinned, + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3d221ff30a..4e11d8005f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index fac1341b39..b8cf019e4f 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields @@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 2cd6dcda3b..9e62a54699 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -3,12 +3,14 @@ import io from flask import send_file from flask_login import current_user from flask_restful import Resource, reqparse +from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder +from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import login_required from services.tools.api_tools_manage_service import ApiToolManageService @@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource): args = parser.parse_args() - return BuiltinToolManageService.update_builtin_tool_provider( - user_id, - tenant_id, - provider, - args["credentials"], - ) + with Session(db.engine) as session: + result = BuiltinToolManageService.update_builtin_tool_provider( + session=session, + user_id=user_id, + tenant_id=tenant_id, + provider_name=provider, + credentials=args["credentials"], + ) + session.commit() + return result class ToolBuiltinProviderGetCredentialsApi(Resource): @@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): @login_required @account_initialization_required def get(self, provider): - user_id = current_user.id tenant_id = current_user.current_tenant_id return BuiltinToolManageService.get_builtin_tool_provider_credentials( - user_id, - tenant_id, - provider, + tenant_id=tenant_id, + provider_name=provider, ) diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index c62fd77d36..32940cbc29 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,5 +1,6 @@ from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound import services @@ -7,6 +8,7 @@ from controllers.service_api import api from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import ( conversation_delete_fields, conversation_infinite_scroll_pagination_fields, @@ -39,14 +41,16 @@ class ConversationApi(Resource): args = parser.parse_args() try: - return ConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.SERVICE_API, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return ConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.SERVICE_API, + sort_by=args["sort_by"], + ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index ada40ec9cb..599401bc6f 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource): parser = reqparse.RequestParser() parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") + parser.add_argument("content", type=str, location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index c3b0cd4f44..fe0d7c74f3 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,11 +1,13 @@ from flask_restful import marshal_with, reqparse from flask_restful.inputs import int_range +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value from models.model import AppMode @@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource): pinned = True if args["pinned"] == "true" else False try: - return WebConversationService.pagination_by_last_id( - app_model=app_model, - user=end_user, - last_id=args["last_id"], - limit=args["limit"], - invoke_from=InvokeFrom.WEB_APP, - pinned=pinned, - sort_by=args["sort_by"], - ) + with Session(db.engine) as session: + return WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=end_user, + last_id=args["last_id"], + limit=args["limit"], + invoke_from=InvokeFrom.WEB_APP, + pinned=pinned, + sort_by=args["sort_by"], + ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 98891f5d00..febaab5328 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource): args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index d6b8eb2855..ae68df6bdc 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse import services from controllers.common import helpers +from controllers.common.errors import RemoteFileUploadError from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy @@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource): url = args["url"] - resp = ssrf_proxy.head(url=url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3) - resp.raise_for_status() + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") file_info = helpers.guess_file_info_from_response(resp) diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py index 18b115dfe4..29709914b7 100644 --- a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py +++ b/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py @@ -4,14 +4,17 @@ import logging import queue import re import threading +from collections.abc import Iterable from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentMessageEvent, QueueLLMChunkEvent, QueueNodeSucceededEvent, QueueTextChunkEvent, + WorkflowQueueMessage, ) -from core.model_manager import ModelManager +from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType @@ -21,7 +24,7 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): if not text_content or text_content.isspace(): return return model_instance.invoke_tts( @@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str): ) -def _process_future(future_queue, audio_queue): +def _process_future( + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None], + audio_queue: queue.Queue[AudioTrunk], +): while True: try: future = future_queue.get() if future is None: break - for audio in future.result(): + invoke_result = future.result() + if not invoke_result: + continue + for audio in invoke_result: audio_base64 = base64.b64encode(bytes(audio)) audio_queue.put(AudioTrunk("responding", audio=audio_base64)) except Exception as e: @@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher: self.logger = logging.getLogger(__name__) self.tenant_id = tenant_id self.msg_text = "" - self._audio_queue = queue.Queue() - self._msg_queue = queue.Queue() + self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() + self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") self.model_manager = ModelManager() self.model_instance = self.model_manager.get_default_model_instance( @@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher: self._runtime_thread = threading.Thread(target=self._runtime).start() self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3) - def publish(self, message): - try: - self._msg_queue.put(message) - except Exception as e: - self.logger.warning(e) + def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /): + self._msg_queue.put(message) def _runtime(self): - future_queue = queue.Queue() + future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue() threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start() while True: try: @@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher: break future_queue.put(None) - def check_and_get_audio(self) -> AudioTrunk | None: + def check_and_get_audio(self): try: if self._last_audio_event and self._last_audio_event.status == "finish": if self.executor: diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 8e1731b314..c7bf37dd08 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -512,7 +512,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._message_to_stream_response( diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 4c4d282e99..3725c6e6dd 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -1,7 +1,6 @@ import queue import time from abc import abstractmethod -from collections.abc import Generator from enum import Enum from typing import Any @@ -11,9 +10,11 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, + MessageQueueMessage, QueueErrorEvent, QueuePingEvent, QueueStopEvent, + WorkflowQueueMessage, ) from extensions.ext_redis import redis_client @@ -37,11 +38,11 @@ class AppQueueManager: AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}" ) - q = queue.Queue() + q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self._q = q - def listen(self) -> Generator: + def listen(self): """ Listen to queue :return: diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index b129904efb..fd84908975 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if not publisher: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None @@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager): while True: - audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id) + audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id) if audio_response: yield audio_response else: @@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa # only publish tts message at text chunk streaming if tts_publisher: - tts_publisher.publish(message=queue_message) + tts_publisher.publish(queue_message) self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 917649f34e..4216cd46cf 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan stream_response=stream_response, ) - def _listen_audio_msg(self, publisher, task_id: str): + def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str): if publisher is None: return None - audio_msg: AudioTrunk = publisher.check_and_get_audio() - if audio_msg and audio_msg.status != "finish": + audio_msg = publisher.check_and_get_audio() + if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish": # audio_str = audio_msg.audio.decode('utf-8', errors='ignore') return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id) return None diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 951fef1fa1..5061804310 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -508,6 +508,12 @@ class WorkflowCycleManage: :param workflow_run: workflow run :return: """ + # Attach WorkflowRun to an active session so "created_by_role" can be accessed. + workflow_run = db.session.merge(workflow_run) + + # Refresh to ensure any expired attributes are fully loaded + db.session.refresh(workflow_run) + created_by = None if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: created_by_account = workflow_run.created_by_account diff --git a/api/core/errors/error.py b/api/core/errors/error.py index 3b186476eb..ad921bc255 100644 --- a/api/core/errors/error.py +++ b/api/core/errors/error.py @@ -1,7 +1,7 @@ from typing import Optional -class LLMError(Exception): +class LLMError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None @@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError): description = "Bad Request" -class ProviderTokenNotInitError(Exception): +class ProviderTokenNotInitError(ValueError): """ Custom exception raised when the provider token is not initialized. """ @@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception): self.description = args[0] if args else self.description -class QuotaExceededError(Exception): +class QuotaExceededError(ValueError): """ Custom exception raised when the quota for a provider has been exceeded. """ @@ -35,7 +35,7 @@ class QuotaExceededError(Exception): description = "Quota Exceeded" -class AppInvokeQuotaExceededError(Exception): +class AppInvokeQuotaExceededError(ValueError): """ Custom exception raised when the quota for an app has been exceeded. """ @@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception): description = "App Invoke Quota Exceeded" -class ModelCurrentlyNotSupportError(Exception): +class ModelCurrentlyNotSupportError(ValueError): """ Custom exception raised when the model not support """ @@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception): description = "Model Currently Not Support" -class InvokeRateLimitError(Exception): +class InvokeRateLimitError(ValueError): """Raised when the Invoke returns rate limit error.""" description = "Rate Limit Error" diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 011ff382ea..584e3e9698 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -118,7 +118,7 @@ class CodeExecutor: return response.data.stdout or "" @classmethod - def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict: + def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]): """ Execute code :param language: code language diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index b7a07b21e1..605719747a 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -25,7 +25,7 @@ class TemplateTransformer(ABC): return runner_script, preload_script @classmethod - def extract_result_str_from_response(cls, response: str) -> str: + def extract_result_str_from_response(cls, response: str): result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL) if not result: raise ValueError("Failed to parse result") @@ -33,13 +33,21 @@ class TemplateTransformer(ABC): return result @classmethod - def transform_response(cls, response: str) -> dict: + def transform_response(cls, response: str) -> Mapping[str, Any]: """ Transform response to dict :param response: response :return: """ - return json.loads(cls.extract_result_str_from_response(response)) + try: + result = json.loads(cls.extract_result_str_from_response(response)) + except json.JSONDecodeError: + raise ValueError("failed to parse response") + if not isinstance(result, dict): + raise ValueError("result must be a dict") + if not all(isinstance(k, str) for k in result): + raise ValueError("result keys must be strings") + return result @classmethod @abstractmethod diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 425b3535c4..424983a819 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5 STATUS_FORCELIST = [429, 500, 502, 503, 504] -class MaxRetriesExceededError(Exception): +class MaxRetriesExceededError(ValueError): """Raised when the maximum number of retries is exceeded.""" pass diff --git a/api/core/llm_generator/output_parser/errors.py b/api/core/llm_generator/output_parser/errors.py index 1e743f1757..0922806ca8 100644 --- a/api/core/llm_generator/output_parser/errors.py +++ b/api/core/llm_generator/output_parser/errors.py @@ -1,2 +1,2 @@ -class OutputParserError(Exception): +class OutputParserError(ValueError): pass diff --git a/api/core/model_runtime/errors/invoke.py b/api/core/model_runtime/errors/invoke.py index edfb19c7d0..7675425361 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/core/model_runtime/errors/invoke.py @@ -1,7 +1,7 @@ from typing import Optional -class InvokeError(Exception): +class InvokeError(ValueError): """Base class for all LLM exceptions.""" description: Optional[str] = None diff --git a/api/core/model_runtime/errors/validate.py b/api/core/model_runtime/errors/validate.py index 7fcd2133f9..16bebcc67d 100644 --- a/api/core/model_runtime/errors/validate.py +++ b/api/core/model_runtime/errors/validate.py @@ -1,4 +1,4 @@ -class CredentialsValidateFailedError(Exception): +class CredentialsValidateFailedError(ValueError): """ Credentials validate failed error """ diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index e1d35ff872..c0ea8c6325 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): "source": { "type": "base64", "media_type": message_content.mime_type, - "data": message_content.data, + "data": message_content.base64_data, }, } sub_messages.append(sub_message_dict) diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index b54668a12d..7d19ccbb74 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import ( PromptMessageContentType, PromptMessageTool, SystemPromptMessage, + TextPromptMessageContent, ToolPromptMessage, UserPromptMessage, ) @@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): """ try: - ping_message = SystemPromptMessage(content="ping") + ping_message = UserPromptMessage(content="ping") self._generate(model, credentials, [ping_message], {"max_output_tokens": 5}) except Exception as ex: @@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel): config_kwargs["stop_sequences"] = stop genai.configure(api_key=credentials["google_api_key"]) - google_model = genai.GenerativeModel(model_name=model) history = [] + system_instruction = None for msg in prompt_messages: # makes message roles strictly alternating content = self._format_message_to_glm_content(msg) if history and history[-1]["role"] == content["role"]: history[-1]["parts"].extend(content["parts"]) + elif content["role"] == "system": + system_instruction = content["parts"][0] else: history.append(content) + if not history: + raise InvokeError("The user prompt message is required. You only add a system prompt message.") + + google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction) response = google_model.generate_content( contents=history, generation_config=genai.types.GenerationConfig(**config_kwargs), @@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel): ) return glm_content elif isinstance(message, SystemPromptMessage): - return {"role": "user", "parts": [to_part(message.content)]} + if isinstance(message.content, list): + text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content) + message.content = "".join(c.data for c in text_contents) + return {"role": "system", "parts": [to_part(message.content)]} elif isinstance(message, ToolPromptMessage): return { "role": "function", diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index b7799ce1fb..a04fc6ee78 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -355,7 +355,13 @@ class TraceTask: def conversation_trace(self, **kwargs): return kwargs - def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id): + def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): + if not workflow_run: + raise ValueError("Workflow run not found") + + db.session.merge(workflow_run) + db.sessoin.refresh(workflow_run) + workflow_id = workflow_run.workflow_id tenant_id = workflow_run.tenant_id workflow_run_id = workflow_run.id diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 992415657e..d17d76333e 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -83,11 +83,15 @@ class DataPostProcessor: if reranking_model: try: model_manager = ModelManager() + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") + if not reranking_provider_name or not reranking_model_name: + return None rerank_model_instance = model_manager.get_model_instance( tenant_id=tenant_id, - provider=reranking_model["reranking_provider_name"], + provider=reranking_provider_name, model_type=ModelType.RERANK, - model=reranking_model["reranking_model_name"], + model=reranking_model_name, ) return rerank_model_instance except InvokeAuthorizationError: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index b2141396d6..18f8d4e839 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -103,7 +103,7 @@ class RetrievalService: if exceptions: exception_message = ";\n".join(exceptions) - raise Exception(exception_message) + raise ValueError(exception_message) if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value: data_post_processor = DataPostProcessor( diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 09f328cd1f..582ad636b1 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController): user_input_form_list = app_model_config.user_input_form_list for input_form in user_input_form_list: # get type - form_type = input_form.keys()[0] + form_type = list(input_form.keys())[0] default = input_form[form_type]["default"] required = input_form[form_type]["required"] label = input_form[form_type]["label"] diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py new file mode 100644 index 0000000000..050b468b74 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.py @@ -0,0 +1,115 @@ +import json +import operator +from typing import Any, Optional, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class BedrockRetrieveTool(BuiltinTool): + bedrock_client: Any = None + knowledge_base_id: str = None + topk: int = None + + def _bedrock_retrieve( + self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None + ): + try: + retrieval_query = {"text": query_input} + + retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}} + + # 如果有元数据过滤条件,则添加到检索配置中 + if metadata_filter: + retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter + + response = self.bedrock_client.retrieve( + knowledgeBaseId=knowledge_base_id, + retrievalQuery=retrieval_query, + retrievalConfiguration=retrieval_configuration, + ) + + results = [] + for result in response.get("retrievalResults", []): + results.append( + { + "content": result.get("content", {}).get("text", ""), + "score": result.get("score", 0.0), + "metadata": result.get("metadata", {}), + } + ) + + return results + except Exception as e: + raise Exception(f"Error retrieving from knowledge base: {str(e)}") + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + line = 0 + try: + if not self.bedrock_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region) + else: + self.bedrock_client = boto3.client("bedrock-agent-runtime") + + line = 1 + if not self.knowledge_base_id: + self.knowledge_base_id = tool_parameters.get("knowledge_base_id") + if not self.knowledge_base_id: + return self.create_text_message("Please provide knowledge_base_id") + + line = 2 + if not self.topk: + self.topk = tool_parameters.get("topk", 5) + + line = 3 + query = tool_parameters.get("query", "") + if not query: + return self.create_text_message("Please input query") + + # 获取元数据过滤条件(如果存在) + metadata_filter_str = tool_parameters.get("metadata_filter") + metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None + + line = 4 + retrieved_docs = self._bedrock_retrieve( + query_input=query, + knowledge_base_id=self.knowledge_base_id, + num_results=self.topk, + metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法 + ) + + line = 5 + # Sort results by score in descending order + sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True) + + line = 6 + return [self.create_json_message(res) for res in sorted_docs] + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}, line : {line}") + + def validate_parameters(self, parameters: dict[str, Any]) -> None: + """ + Validate the parameters + """ + if not parameters.get("knowledge_base_id"): + raise ValueError("knowledge_base_id is required") + + if not parameters.get("query"): + raise ValueError("query is required") + + # 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供) + metadata_filter_str = parameters.get("metadata_filter") + if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict): + raise ValueError("metadata_filter must be a valid JSON object") diff --git a/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml new file mode 100644 index 0000000000..9e51d52def --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/bedrock_retrieve.yaml @@ -0,0 +1,87 @@ +identity: + name: bedrock_retrieve + author: AWS + label: + en_US: Bedrock Retrieve + zh_Hans: Bedrock检索 + pt_BR: Bedrock Retrieve + icon: icon.svg + +description: + human: + en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明 + pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. + llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool + +parameters: + - name: knowledge_base_id + type: string + required: true + label: + en_US: Bedrock Knowledge Base ID + zh_Hans: Bedrock知识库ID + pt_BR: Bedrock Knowledge Base ID + human_description: + en_US: ID of the Bedrock Knowledge Base to retrieve from + zh_Hans: 用于检索的Bedrock知识库ID + pt_BR: ID of the Bedrock Knowledge Base to retrieve from + llm_description: ID of the Bedrock Knowledge Base to retrieve from + form: form + + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + pt_BR: Query string + human_description: + en_US: The search query to retrieve relevant information + zh_Hans: 用于检索相关信息的查询语句 + pt_BR: The search query to retrieve relevant information + llm_description: The search query to retrieve relevant information + form: llm + + - name: topk + type: number + required: false + form: form + label: + en_US: Limit for results count + zh_Hans: 返回结果数量限制 + pt_BR: Limit for results count + human_description: + en_US: Maximum number of results to return + zh_Hans: 最大返回结果数量 + pt_BR: Maximum number of results to return + min: 1 + max: 10 + default: 5 + + - name: aws_region + type: string + required: false + label: + en_US: AWS Region + zh_Hans: AWS 区域 + pt_BR: AWS Region + human_description: + en_US: AWS region where the Bedrock Knowledge Base is located + zh_Hans: Bedrock知识库所在的AWS区域 + pt_BR: AWS region where the Bedrock Knowledge Base is located + llm_description: AWS region where the Bedrock Knowledge Base is located + form: form + + - name: metadata_filter + type: string + required: false + label: + en_US: Metadata Filter + zh_Hans: 元数据过滤器 + pt_BR: Metadata Filter + human_description: + en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' + zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})' + pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})' + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.py b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py new file mode 100644 index 0000000000..954dbe35a4 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.py @@ -0,0 +1,357 @@ +import base64 +import json +import logging +import re +from datetime import datetime +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class NovaCanvasTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Canvas model for image generation + """ + # Get common parameters + prompt = tool_parameters.get("prompt", "") + image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip() + if not prompt: + return self.create_text_message("Please provide a text prompt for image generation.") + if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3": + return self.create_text_message("Please provide an valid S3 URI for image output.") + + task_type = tool_parameters.get("task_type", "TEXT_IMAGE") + aws_region = tool_parameters.get("aws_region", "us-east-1") + + # Get common image generation config parameters + width = tool_parameters.get("width", 1024) + height = tool_parameters.get("height", 1024) + cfg_scale = tool_parameters.get("cfg_scale", 8.0) + negative_prompt = tool_parameters.get("negative_prompt", "") + seed = tool_parameters.get("seed", 0) + quality = tool_parameters.get("quality", "standard") + + # Handle S3 image if provided + image_input_s3uri = tool_parameters.get("image_input_s3uri", "") + if task_type != "TEXT_IMAGE": + if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3": + return self.create_text_message("Please provide a valid S3 URI for image to image generation.") + + # Parse S3 URI + parsed_uri = urlparse(image_input_s3uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + # Initialize S3 client and download image + s3_client = boto3.client("s3") + response = s3_client.get_object(Bucket=bucket, Key=key) + image_data = response["Body"].read() + + # Base64 encode the image + input_image = base64.b64encode(image_data).decode("utf-8") + + try: + # Initialize Bedrock client + bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region) + + # Base image generation config + image_generation_config = { + "width": width, + "height": height, + "cfgScale": cfg_scale, + "seed": seed, + "numberOfImages": 1, + "quality": quality, + } + + # Prepare request body based on task type + body = {"imageGenerationConfig": image_generation_config} + + if task_type == "TEXT_IMAGE": + body["taskType"] = "TEXT_IMAGE" + body["textToImageParams"] = {"text": prompt} + if negative_prompt: + body["textToImageParams"]["negativeText"] = negative_prompt + + elif task_type == "COLOR_GUIDED_GENERATION": + colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680") + if not self._validate_color_string(colors): + return self.create_text_message("Please provide valid colors in hexadecimal format.") + + body["taskType"] = "COLOR_GUIDED_GENERATION" + body["colorGuidedGenerationParams"] = { + "colors": colors.split("-"), + "referenceImage": input_image, + "text": prompt, + } + if negative_prompt: + body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt + + elif task_type == "IMAGE_VARIATION": + similarity_strength = tool_parameters.get("similarity_strength", 0.5) + + body["taskType"] = "IMAGE_VARIATION" + body["imageVariationParams"] = { + "images": [input_image], + "similarityStrength": similarity_strength, + "text": prompt, + } + if negative_prompt: + body["imageVariationParams"]["negativeText"] = negative_prompt + + elif task_type == "INPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image inpainting.") + + body["taskType"] = "INPAINTING" + body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt} + if negative_prompt: + body["inPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "OUTPAINTING": + mask_prompt = tool_parameters.get("mask_prompt") + if not mask_prompt: + return self.create_text_message("Please provide a mask prompt for image outpainting.") + outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT") + + body["taskType"] = "OUTPAINTING" + body["outPaintingParams"] = { + "image": input_image, + "maskPrompt": mask_prompt, + "outPaintingMode": outpainting_mode, + "text": prompt, + } + if negative_prompt: + body["outPaintingParams"]["negativeText"] = negative_prompt + + elif task_type == "BACKGROUND_REMOVAL": + body["taskType"] = "BACKGROUND_REMOVAL" + body["backgroundRemovalParams"] = {"image": input_image} + + else: + return self.create_text_message(f"Unsupported task type: {task_type}") + + # Call Nova Canvas model + response = bedrock.invoke_model( + body=json.dumps(body), + modelId="amazon.nova-canvas-v1:0", + accept="application/json", + contentType="application/json", + ) + + # Process response + response_body = json.loads(response.get("body").read()) + if response_body.get("error"): + raise Exception(f"Error in model response: {response_body.get('error')}") + base64_image = response_body.get("images")[0] + + # Upload to S3 if image_output_s3uri is provided + try: + # Parse S3 URI for output + parsed_uri = urlparse(image_output_s3uri) + output_bucket = parsed_uri.netloc + output_base_path = parsed_uri.path.lstrip("/") + # Generate filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_key = f"{output_base_path}/canvas-output-{timestamp}.png" + + # Initialize S3 client if not already done + s3_client = boto3.client("s3", region_name=aws_region) + + # Decode base64 image and upload to S3 + image_data = base64.b64decode(base64_image) + s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png") + logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}") + except Exception as e: + logger.exception("Failed to upload image to S3") + # Return image + return [ + self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"), + self.create_blob_message( + blob=base64.b64decode(base64_image), + meta={"mime_type": "image/png"}, + save_as=self.VariableKey.IMAGE.value, + ), + ] + + except Exception as e: + return self.create_text_message(f"Failed to generate image: {str(e)}") + + def _validate_color_string(self, color_string) -> bool: + color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$" + + if re.match(color_pattern, color_string): + return True + return False + + def get_runtime_parameters(self) -> list[ToolParameter]: + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the image you want to generate or modify", + zh_Hans="您想要生成或修改的图像的文本描述", + ), + llm_description="Describe the image you want to generate or how you want to modify the input image", + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"), + ), + ToolParameter( + name="image_output_s3uri", + label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI" + ), + ), + ToolParameter( + name="width", + label=I18nObject(en_US="Width", zh_Hans="宽度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"), + ), + ToolParameter( + name="height", + label=I18nObject(en_US="Height", zh_Hans="高度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=1024, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"), + ), + ToolParameter( + name="cfg_scale", + label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=8.0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词" + ), + ), + ToolParameter( + name="negative_prompt", + label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容" + ), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="us-east-1", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="task_type", + label=I18nObject(en_US="Task Type", zh_Hans="任务类型"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="TEXT_IMAGE", + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"), + ), + ToolParameter( + name="quality", + label=I18nObject(en_US="Quality", zh_Hans="质量"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="standard", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)" + ), + ), + ToolParameter( + name="colors", + label=I18nObject(en_US="Colors", zh_Hans="颜色"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680", + zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680", + ), + ), + ToolParameter( + name="similarity_strength", + label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0.5, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="How similar the generated image should be to the input image (0.0 to 1.0)", + zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)", + ), + ), + ToolParameter( + name="mask_prompt", + label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description to generate mask for inpainting/outpainting", + zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述", + ), + ), + ToolParameter( + name="outpainting_mode", + label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default="DEFAULT", + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Mode for outpainting (DEFAULT or other supported modes)", + zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml new file mode 100644 index 0000000000..a72fd9c8ef --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_canvas.yaml @@ -0,0 +1,175 @@ +identity: + name: nova_canvas + author: AWS + label: + en_US: AWS Bedrock Nova Canvas + zh_Hans: AWS Bedrock Nova Canvas + icon: icon.svg +description: + human: + en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html + zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。 + llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. +parameters: + - name: task_type + type: string + required: false + default: TEXT_IMAGE + label: + en_US: Task Type + zh_Hans: 任务类型 + human_description: + en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL) + zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除) + form: llm + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the image you want to generate or modify + zh_Hans: 您想要生成或修改的图像的文本描述 + llm_description: Describe the image you want to generate or how you want to modify the input image + form: llm + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input image s3 uri + zh_Hans: 输入图片的s3 uri + human_description: + en_US: The input image to modify (required for all modes except TEXT_IMAGE) + zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要) + llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE. + form: llm + - name: image_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png + zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传 + llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename. + form: form + - name: negative_prompt + type: string + required: false + label: + en_US: Negative Prompt + zh_Hans: 负面提示词 + human_description: + en_US: Things you don't want in the generated image + zh_Hans: 您不想在生成的图像中出现的内容 + form: llm + - name: width + type: number + required: false + label: + en_US: Width + zh_Hans: 宽度 + human_description: + en_US: Width of the generated image + zh_Hans: 生成图像的宽度 + form: form + default: 1024 + - name: height + type: number + required: false + label: + en_US: Height + zh_Hans: 高度 + human_description: + en_US: Height of the generated image + zh_Hans: 生成图像的高度 + form: form + default: 1024 + - name: cfg_scale + type: number + required: false + label: + en_US: CFG Scale + zh_Hans: CFG比例 + human_description: + en_US: How strongly the image should conform to the prompt + zh_Hans: 图像应该多大程度上符合提示词 + form: form + default: 8.0 + - name: seed + type: number + required: false + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for image generation + zh_Hans: 图像生成的随机种子 + form: form + default: 0 + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + - name: quality + type: string + required: false + default: standard + label: + en_US: Quality + zh_Hans: 质量 + human_description: + en_US: Quality of the generated image (standard or premium) + zh_Hans: 生成图像的质量(标准或高级) + form: form + - name: colors + type: string + required: false + label: + en_US: Colors + zh_Hans: 颜色 + human_description: + en_US: List of colors for color-guided generation + zh_Hans: 颜色引导生成的颜色列表 + form: form + - name: similarity_strength + type: number + required: false + default: 0.5 + label: + en_US: Similarity Strength + zh_Hans: 相似度强度 + human_description: + en_US: How similar the generated image should be to the input image (0.0 to 1.0) + zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0) + form: form + - name: mask_prompt + type: string + required: false + label: + en_US: Mask Prompt + zh_Hans: 蒙版提示词 + human_description: + en_US: Text description to generate mask for inpainting/outpainting + zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述 + form: llm + - name: outpainting_mode + type: string + required: false + default: DEFAULT + label: + en_US: Outpainting Mode + zh_Hans: 外补绘制模式 + human_description: + en_US: Mode for outpainting (DEFAULT or other supported modes) + zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式) + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.py b/api/core/tools/provider/builtin/aws/tools/nova_reel.py new file mode 100644 index 0000000000..bfd3d302b2 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.py @@ -0,0 +1,371 @@ +import base64 +import logging +import time +from io import BytesIO +from typing import Any, Optional, Union +from urllib.parse import urlparse + +import boto3 +from botocore.exceptions import ClientError +from PIL import Image + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +NOVA_REEL_DEFAULT_REGION = "us-east-1" +NOVA_REEL_DEFAULT_DIMENSION = "1280x720" +NOVA_REEL_DEFAULT_FPS = 24 +NOVA_REEL_DEFAULT_DURATION = 6 +NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0" +NOVA_REEL_STATUS_CHECK_INTERVAL = 5 + +# Image requirements +NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280 +NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720 +NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB" + + +class NovaReelTool(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + Invoke AWS Bedrock Nova Reel model for video generation. + + Args: + user_id: The ID of the user making the request + tool_parameters: Dictionary containing the tool parameters + + Returns: + ToolInvokeMessage containing either the video content or status information + """ + try: + # Validate and extract parameters + params = self._validate_and_extract_parameters(tool_parameters) + if isinstance(params, ToolInvokeMessage): + return params + + # Initialize AWS clients + bedrock, s3_client = self._initialize_aws_clients(params["aws_region"]) + + # Prepare model input + model_input = self._prepare_model_input(params, s3_client) + if isinstance(model_input, ToolInvokeMessage): + return model_input + + # Start video generation + invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"]) + invocation_arn = invocation["invocationArn"] + + # Handle async/sync mode + return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"]) + + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + error_message = e.response.get("Error", {}).get("Message", str(e)) + logger.exception(f"AWS API error: {error_code} - {error_message}") + return self.create_text_message(f"AWS service error: {error_code} - {error_message}") + except Exception as e: + logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to generate video: {str(e)}") + + def _validate_and_extract_parameters( + self, tool_parameters: dict[str, Any] + ) -> Union[dict[str, Any], ToolInvokeMessage]: + """Validate and extract parameters from the input dictionary.""" + prompt = tool_parameters.get("prompt", "") + video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip() + + # Validate required parameters + if not prompt: + return self.create_text_message("Please provide a text prompt for video generation.") + if not video_output_s3uri: + return self.create_text_message("Please provide an S3 URI for video output.") + + # Validate S3 URI format + if not video_output_s3uri.startswith("s3://"): + return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") + + # Ensure S3 URI ends with '/' + video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/" + + return { + "prompt": prompt, + "video_output_s3uri": video_output_s3uri, + "image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(), + "aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION), + "dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION), + "seed": int(tool_parameters.get("seed", 0)), + "fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)), + "duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)), + "async_mode": bool(tool_parameters.get("async", True)), + } + + def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]: + """Initialize AWS Bedrock and S3 clients.""" + bedrock = boto3.client(service_name="bedrock-runtime", region_name=region) + s3_client = boto3.client("s3", region_name=region) + return bedrock, s3_client + + def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]: + """Prepare the input for the Nova Reel model.""" + model_input = { + "taskType": "TEXT_VIDEO", + "textToVideoParams": {"text": params["prompt"]}, + "videoGenerationConfig": { + "durationSeconds": params["duration"], + "fps": params["fps"], + "dimension": params["dimension"], + "seed": params["seed"], + }, + } + + # Add image if provided + if params["image_input_s3uri"]: + try: + image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"]) + if not image_data: + return self.create_text_message("Failed to retrieve image from S3") + + # Process and validate image + processed_image = self._process_and_validate_image(image_data) + if isinstance(processed_image, ToolInvokeMessage): + return processed_image + + # Convert processed image to base64 + img_buffer = BytesIO() + processed_image.save(img_buffer, format="PNG") + img_buffer.seek(0) + input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8") + + model_input["textToVideoParams"]["images"] = [ + {"format": "png", "source": {"bytes": input_image_base64}} + ] + except Exception as e: + logger.error(f"Error processing input image: {str(e)}", exc_info=True) + return self.create_text_message(f"Failed to process input image: {str(e)}") + + return model_input + + def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]: + """ + Process and validate the input image according to Nova Reel requirements. + + Requirements: + - Must be 1280x720 pixels + - Must be RGB format (8 bits per channel) + - If PNG, alpha channel must not have transparent/translucent pixels + """ + try: + # Open image + img = Image.open(BytesIO(image_data)) + + # Convert RGBA to RGB if needed, ensuring no transparency + if img.mode == "RGBA": + # Check for transparency + if img.getchannel("A").getextrema()[0] < 255: + return self.create_text_message( + "PNG image contains transparent or translucent pixels, which is not supported. " + "Please provide an image without transparency." + ) + # Convert to RGB + img = img.convert("RGB") + elif img.mode != "RGB": + # Convert any other mode to RGB + img = img.convert("RGB") + + # Validate/adjust dimensions + if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT): + logger.warning( + f"Image dimensions {img.size} do not match required dimensions " + f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..." + ) + img = img.resize( + (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS + ) + + # Validate bit depth + if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE: + return self.create_text_message( + f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel" + ) + + return img + + except Exception as e: + logger.error(f"Error processing image: {str(e)}", exc_info=True) + return self.create_text_message( + "Failed to process image. Please ensure the image is a valid JPEG or PNG file." + ) + + def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]: + """Download and return image data from S3.""" + parsed_uri = urlparse(s3_uri) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + + response = s3_client.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + + def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]: + """Start the async video generation process.""" + return bedrock.start_async_invoke( + modelId=NOVA_REEL_MODEL_ID, + modelInput=model_input, + outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}}, + ) + + def _handle_generation_mode( + self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool + ) -> ToolInvokeMessage: + """Handle async or sync video generation mode.""" + invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + video_uri = f"{video_path}/output.mp4" + + if async_mode: + return self.create_text_message( + f"Video generation started.\nInvocation ARN: {invocation_arn}\n" + f"Video will be available at: {video_uri}" + ) + + return self._wait_for_completion(bedrock, s3_client, invocation_arn) + + def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage: + """Wait for video generation completion and handle the result.""" + while True: + status_response = bedrock.get_async_invoke(invocationArn=invocation_arn) + status = status_response["status"] + video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"] + + if status == "Completed": + return self._handle_completed_video(s3_client, video_path) + elif status == "Failed": + failure_message = status_response.get("failureMessage", "Unknown error") + return self.create_text_message(f"Video generation failed.\nError: {failure_message}") + elif status == "InProgress": + time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL) + else: + return self.create_text_message(f"Unexpected status: {status}") + + def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage: + """Handle completed video generation and return the result.""" + parsed_uri = urlparse(video_path) + bucket = parsed_uri.netloc + key = parsed_uri.path.lstrip("/") + "/output.mp4" + + try: + response = s3_client.get_object(Bucket=bucket, Key=key) + video_content = response["Body"].read() + return [ + self.create_text_message(f"Video is available at: {video_path}/output.mp4"), + self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"), + ] + except Exception as e: + logger.error(f"Error downloading video: {str(e)}", exc_info=True) + return self.create_text_message( + f"Video generation completed but failed to download video: {str(e)}\n" + f"Video is available at: s3://{bucket}/{key}" + ) + + def get_runtime_parameters(self) -> list[ToolParameter]: + """Define the tool's runtime parameters.""" + parameters = [ + ToolParameter( + name="prompt", + label=I18nObject(en_US="Prompt", zh_Hans="提示词"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述" + ), + llm_description="Describe the video you want to generate", + ), + ToolParameter( + name="video_output_s3uri", + label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=True, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI" + ), + ), + ToolParameter( + name="dimension", + label=I18nObject(en_US="Dimension", zh_Hans="尺寸"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_DIMENSION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"), + ), + ToolParameter( + name="duration", + label=I18nObject(en_US="Duration", zh_Hans="时长"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_DURATION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"), + ), + ToolParameter( + name="seed", + label=I18nObject(en_US="Seed", zh_Hans="种子值"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=0, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"), + ), + ToolParameter( + name="fps", + label=I18nObject(en_US="FPS", zh_Hans="帧率"), + type=ToolParameter.ToolParameterType.NUMBER, + required=False, + default=NOVA_REEL_DEFAULT_FPS, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject( + en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数" + ), + ), + ToolParameter( + name="async", + label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"), + type=ToolParameter.ToolParameterType.BOOLEAN, + required=False, + default=True, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)", + zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)", + ), + ), + ToolParameter( + name="aws_region", + label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + default=NOVA_REEL_DEFAULT_REGION, + form=ToolParameter.ToolParameterForm.FORM, + human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"), + ), + ToolParameter( + name="image_input_s3uri", + label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"), + type=ToolParameter.ToolParameterType.STRING, + required=False, + form=ToolParameter.ToolParameterForm.LLM, + human_description=I18nObject( + en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame", + zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI", + ), + ), + ] + + return parameters diff --git a/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml new file mode 100644 index 0000000000..16df5ba5c9 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/nova_reel.yaml @@ -0,0 +1,124 @@ +identity: + name: nova_reel + author: AWS + label: + en_US: AWS Bedrock Nova Reel + zh_Hans: AWS Bedrock Nova Reel + icon: icon.svg +description: + human: + en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html + llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution. + +parameters: + - name: prompt + type: string + required: true + label: + en_US: Prompt + zh_Hans: 提示词 + human_description: + en_US: Text description of the video you want to generate + zh_Hans: 您想要生成的视频的文本描述 + llm_description: Describe the video you want to generate + form: llm + + - name: video_output_s3uri + type: string + required: true + label: + en_US: Output S3 URI + zh_Hans: 输出S3 URI + human_description: + en_US: S3 URI where the generated video will be stored + zh_Hans: 生成的视频将存储的S3 URI + form: form + + - name: dimension + type: string + required: false + default: 1280x720 + label: + en_US: Dimension + zh_Hans: 尺寸 + human_description: + en_US: Video dimensions (width x height) + zh_Hans: 视频尺寸(宽 x 高) + form: form + + - name: duration + type: number + required: false + default: 6 + label: + en_US: Duration + zh_Hans: 时长 + human_description: + en_US: Video duration in seconds + zh_Hans: 视频时长(秒) + form: form + + - name: seed + type: number + required: false + default: 0 + label: + en_US: Seed + zh_Hans: 种子值 + human_description: + en_US: Random seed for video generation + zh_Hans: 视频生成的随机种子 + form: form + + - name: fps + type: number + required: false + default: 24 + label: + en_US: FPS + zh_Hans: 帧率 + human_description: + en_US: Frames per second for the generated video + zh_Hans: 生成视频的每秒帧数 + form: form + + - name: async + type: boolean + required: false + default: true + label: + en_US: Async Mode + zh_Hans: 异步模式 + human_description: + en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion) + zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成) + form: llm + + - name: aws_region + type: string + required: false + default: us-east-1 + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: AWS region for Bedrock service + zh_Hans: Bedrock 服务的 AWS 区域 + form: form + + - name: image_input_s3uri + type: string + required: false + label: + en_US: Input Image S3 URI + zh_Hans: 输入图像S3 URI + human_description: + en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame + zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI + form: llm + +development: + dependencies: + - boto3 + - pillow diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.py b/api/core/tools/provider/builtin/aws/tools/s3_operator.py new file mode 100644 index 0000000000..e4026b07a8 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.py @@ -0,0 +1,80 @@ +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class S3Operator(BuiltinTool): + s3_client: Any = None + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + # Initialize S3 client if not already done + if not self.s3_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.s3_client = boto3.client("s3") + + # Parse S3 URI + s3_uri = tool_parameters.get("s3_uri") + if not s3_uri: + return self.create_text_message("s3_uri parameter is required") + + parsed_uri = urlparse(s3_uri) + if parsed_uri.scheme != "s3": + return self.create_text_message("Invalid S3 URI format. Must start with 's3://'") + + bucket = parsed_uri.netloc + # Remove leading slash from key + key = parsed_uri.path.lstrip("/") + + operation_type = tool_parameters.get("operation_type", "read") + generate_presign_url = tool_parameters.get("generate_presign_url", False) + presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour + + if operation_type == "write": + text_content = tool_parameters.get("text_content") + if not text_content: + return self.create_text_message("text_content parameter is required for write operation") + + # Write content to S3 + self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8")) + result = f"s3://{bucket}/{key}" + + # Generate presigned URL for the written object if requested + if generate_presign_url: + result = self.s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry + ) + + else: # read operation + # Get object from S3 + response = self.s3_client.get_object(Bucket=bucket, Key=key) + result = response["Body"].read().decode("utf-8") + + # Generate presigned URL if requested + if generate_presign_url: + result = self.s3_client.generate_presigned_url( + "get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry + ) + + return self.create_text_message(text=result) + + except self.s3_client.exceptions.NoSuchBucket: + return self.create_text_message(f"Bucket '{bucket}' does not exist") + except self.s3_client.exceptions.NoSuchKey: + return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'") + except Exception as e: + return self.create_text_message(f"Exception: {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml new file mode 100644 index 0000000000..642fc2966e --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/s3_operator.yaml @@ -0,0 +1,98 @@ +identity: + name: s3_operator + author: AWS + label: + en_US: AWS S3 Operator + zh_Hans: AWS S3 读写器 + pt_BR: AWS S3 Operator + icon: icon.svg +description: + human: + en_US: AWS S3 Writer and Reader + zh_Hans: 读写S3 bucket中的文件 + pt_BR: AWS S3 Writer and Reader + llm: AWS S3 Writer and Reader +parameters: + - name: text_content + type: string + required: false + label: + en_US: The text to write + zh_Hans: 待写入的文本 + pt_BR: The text to write + human_description: + en_US: The text to write + zh_Hans: 待写入的文本 + pt_BR: The text to write + llm_description: The text to write + form: llm + - name: s3_uri + type: string + required: true + label: + en_US: s3 uri + zh_Hans: s3 uri + pt_BR: s3 uri + human_description: + en_US: s3 uri + zh_Hans: s3 uri + pt_BR: s3 uri + llm_description: s3 uri + form: llm + - name: aws_region + type: string + required: true + label: + en_US: region of bucket + zh_Hans: bucket 所在的region + pt_BR: region of bucket + human_description: + en_US: region of bucket + zh_Hans: bucket 所在的region + pt_BR: region of bucket + llm_description: region of bucket + form: form + - name: operation_type + type: select + required: true + label: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + human_description: + en_US: operation type + zh_Hans: 操作类型 + pt_BR: operation type + default: read + options: + - value: read + label: + en_US: read + zh_Hans: 读 + - value: write + label: + en_US: write + zh_Hans: 写 + form: form + - name: generate_presign_url + type: boolean + required: false + label: + en_US: Generate presigned URL + zh_Hans: 生成预签名URL + human_description: + en_US: Whether to generate a presigned URL for the S3 object + zh_Hans: 是否生成S3对象的预签名URL + default: false + form: form + - name: presign_expiry + type: number + required: false + label: + en_US: Presigned URL expiration time + zh_Hans: 预签名URL有效期 + human_description: + en_US: Expiration time in seconds for the presigned URL + zh_Hans: 预签名URL的有效期(秒) + default: 3600 + form: form diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 636debffd4..48aac75dbb 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -210,7 +210,7 @@ class ApiTool(Tool): ) return response else: - raise ValueError(f"Invalid http method {self.method}") + raise ValueError(f"Invalid http method {method}") def _convert_body_property_any_of( self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10 diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 5052f0897a..2aaca6d82e 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type from typing import Optional, Union from uuid import uuid4 -from httpx import get +import httpx from configs import dify_config +from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage from models.model import MessageFile @@ -94,12 +95,11 @@ class ToolFileManager: ) -> ToolFile: # try to download image try: - response = get(file_url) + response = ssrf_proxy.get(file_url) response.raise_for_status() blob = response.content - except Exception as e: - logger.exception(f"Failed to download file from {file_url}") - raise + except httpx.TimeoutException as e: + raise ValueError(f"timeout when downloading file from {file_url}") mimetype = guess_type(file_url)[0] or "octet/stream" extension = guess_extension(mimetype) or ".bin" diff --git a/api/core/workflow/nodes/base/exc.py b/api/core/workflow/nodes/base/exc.py index ec134e031c..aeecf40640 100644 --- a/api/core/workflow/nodes/base/exc.py +++ b/api/core/workflow/nodes/base/exc.py @@ -1,4 +1,4 @@ -class BaseNodeError(Exception): +class BaseNodeError(ValueError): """Base class for node errors.""" pass diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 19b9078a5c..4e371ca436 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Optional, Union +from typing import Any, Optional from configs import dify_config from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage @@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]): ) # Transform result - result = self._transform_result(result, self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self.node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]): return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - def _check_string(self, value: str, variable: str) -> str: + def _check_string(self, value: str | None, variable: str) -> str | None: """ Check string :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, str): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a string") + raise OutputValidationError(f"Output variable `{variable}` must be a string") if len(value) > dify_config.CODE_MAX_STRING_LENGTH: raise OutputValidationError( @@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]): return value.replace("\x00", "") - def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]: + def _check_number(self, value: int | float | None, variable: str) -> int | float | None: """ Check number :param value: value :param variable: variable :return: """ + if value is None: + return None if not isinstance(value, int | float): - if value is None: - return None - else: - raise OutputValidationError(f"Output variable `{variable}` must be a number") + raise OutputValidationError(f"Output variable `{variable}` must be a number") if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER: raise OutputValidationError( @@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]): return value def _transform_result( - self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1 - ) -> dict: - """ - Transform result - :param result: result - :param output_schema: output schema - :return: - """ + self, + result: Mapping[str, Any], + output_schema: Optional[dict[str, CodeNodeData.Output]], + prefix: str = "", + depth: int = 1, + ): if depth > dify_config.CODE_MAX_DEPTH: raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 95d0ea3aab..6d82dbe6d7 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -1,6 +1,7 @@ import csv import io import json +import logging import os import tempfile @@ -22,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError +logger = logging.getLogger(__name__) + class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): """ @@ -177,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str: def _extract_text_from_doc(file_content: bytes) -> str: + """ + Extract text from a DOC/DOCX file. + For now support only paragraph and table add more if needed + """ try: doc_file = io.BytesIO(file_content) doc = docx.Document(doc_file) - return "\n".join([paragraph.text for paragraph in doc.paragraphs]) + text = [] + # Process paragraphs + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + text.append(paragraph.text) + + # Process tables + for table in doc.tables: + # Table header + try: + # table maybe cause errors so ignore it. + if len(table.rows) > 0 and table.rows[0].cells is not None: + # Check if any cell in the table has text + has_content = False + for row in table.rows: + if any(cell.text.strip() for cell in row.cells): + has_content = True + break + + if has_content: + markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n" + markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n" + for row in table.rows[1:]: + markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n" + text.append(markdown_table) + except Exception as e: + logger.warning(f"Failed to extract table from DOC/DOCX: {e}") + continue + + return "\n".join(text) except Exception as e: raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 5b960ea615..c8c854a43b 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode): error=str(e), metadata={}, ) + except Exception as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, + error=str(e), + metadata={}, + ) error = None diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 15cfabe478..5043e25e2b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -154,8 +154,7 @@ class QuestionClassifierNode(LLMNode): }, llm_usage=usage, ) - - except ValueError as e: + except Exception as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/core/workflow/nodes/variable_assigner/common/exc.py index a1178fb020..f8dbedc290 100644 --- a/api/core/workflow/nodes/variable_assigner/common/exc.py +++ b/api/core/workflow/nodes/variable_assigner/common/exc.py @@ -1,4 +1,4 @@ -class VariableOperatorNodeError(Exception): +class VariableOperatorNodeError(ValueError): """Base error type, don't use directly.""" pass diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 01d95dcfb3..13034f5cf5 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -116,8 +116,11 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: + upload_file_id = mapping.get("upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") stmt = select(UploadFile).where( - UploadFile.id == mapping.get("upload_file_id"), + UploadFile.id == upload_file_id, UploadFile.tenant_id == tenant_id, ) diff --git a/api/libs/json_in_md_parser.py b/api/libs/json_in_md_parser.py index 41c5d20c4b..267af611f5 100644 --- a/api/libs/json_in_md_parser.py +++ b/api/libs/json_in_md_parser.py @@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict: extracted_content = json_string[start_index:end_index].strip() parsed = json.loads(extracted_content) else: - raise Exception("Could not find JSON block in the output.") + raise ValueError("could not find json block in the output.") return parsed @@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict: try: json_obj = parse_json_markdown(text) except json.JSONDecodeError as e: - raise OutputParserError(f"Got invalid JSON object. Error: {e}") + raise OutputParserError(f"got invalid json object. error: {e}") for key in expected_keys: if key not in json_obj: raise OutputParserError( - f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}" + f"got invalid return object. expected key `{key}` to be present, but got {json_obj}" ) return json_obj diff --git a/api/models/account.py b/api/models/account.py index ce17b90def..932ba1da57 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -2,6 +2,7 @@ import enum import json from flask_login import UserMixin +from sqlalchemy import func from .engine import db from .types import StringUUID @@ -30,11 +31,11 @@ class Account(UserMixin, db.Model): timezone = db.Column(db.String(255)) last_login_at = db.Column(db.DateTime) last_login_ip = db.Column(db.String(255)) - last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying")) initialized_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def is_password_set(self): @@ -187,8 +188,8 @@ class Tenant(db.Model): plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) custom_config = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def get_accounts(self) -> list[Account]: return ( @@ -228,8 +229,8 @@ class TenantAccountJoin(db.Model): current = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) role = db.Column(db.String(16), nullable=False, server_default="normal") invited_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AccountIntegrate(db.Model): @@ -245,8 +246,8 @@ class AccountIntegrate(db.Model): provider = db.Column(db.String(16), nullable=False) open_id = db.Column(db.String(255), nullable=False) encrypted_token = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class InvitationCode(db.Model): @@ -265,4 +266,4 @@ class InvitationCode(db.Model): used_by_tenant_id = db.Column(StringUUID) used_by_account_id = db.Column(StringUUID) deprecated_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 4d4182cabd..fbffe7a3b2 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -1,5 +1,7 @@ import enum +from sqlalchemy import func + from .engine import db from .types import StringUUID @@ -23,4 +25,4 @@ class APIBasedExtension(db.Model): name = db.Column(db.String(255), nullable=False) api_endpoint = db.Column(db.String(255), nullable=False) api_key = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/dataset.py b/api/models/dataset.py index 97e4d6c0ef..7279e8d5b3 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -50,9 +50,9 @@ class Dataset(db.Model): indexing_technique = db.Column(db.String(255), nullable=True) index_struct = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) embedding_model = db.Column(db.String(255), nullable=True) embedding_model_provider = db.Column(db.String(255), nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True) @@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model): mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying")) rules = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] @@ -264,7 +264,7 @@ class Document(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) created_api_request_id = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) # start processing processing_started_at = db.Column(db.DateTime, nullable=True) @@ -303,7 +303,7 @@ class Document(db.Model): archived_reason = db.Column(db.String(255), nullable=True) archived_by = db.Column(StringUUID, nullable=True) archived_at = db.Column(db.DateTime, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying")) @@ -527,9 +527,9 @@ class DocumentSegment(db.Model): disabled_by = db.Column(StringUUID, nullable=True) status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) indexing_at = db.Column(db.DateTime, nullable=True) completed_at = db.Column(db.DateTime, nullable=True) error = db.Column(db.Text, nullable=True) @@ -697,7 +697,7 @@ class Embedding(db.Model): ) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying")) def set_embedding(self, embedding_data: list[float]): @@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model): model_name = db.Column(db.String(255), nullable=False) type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False) collection_name = db.Column(db.String(64), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TidbAuthBinding(db.Model): @@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model): status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING")) account = db.Column(db.String(255), nullable=False) password = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Whitelist(db.Model): @@ -751,7 +751,7 @@ class Whitelist(db.Model): id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) tenant_id = db.Column(StringUUID, nullable=True) category = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class DatasetPermission(db.Model): @@ -768,7 +768,7 @@ class DatasetPermission(db.Model): account_id = db.Column(StringUUID, nullable=False) tenant_id = db.Column(StringUUID, nullable=False) has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ExternalKnowledgeApis(db.Model): @@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model): tenant_id = db.Column(StringUUID, nullable=False) settings = db.Column(db.Text, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def to_dict(self): return { @@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model): dataset_id = db.Column(StringUUID, nullable=False) external_knowledge_id = db.Column(db.Text, nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/model.py b/api/models/model.py index f484acde78..ebf0c16c56 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -30,7 +30,7 @@ class DifySetup(db.Model): __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) version = db.Column(db.String(255), nullable=False) - setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class AppMode(StrEnum): @@ -85,9 +85,9 @@ class App(db.Model): tracing = db.Column(db.Text, nullable=True) max_active_requests = db.Column(db.Integer, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) @property @@ -226,9 +226,9 @@ class AppModelConfig(db.Model): model_id = db.Column(db.String(255), nullable=True) configs = db.Column(db.JSON, nullable=True) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) opening_statement = db.Column(db.Text) suggested_questions = db.Column(db.Text) suggested_questions_after_answer = db.Column(db.Text) @@ -482,8 +482,8 @@ class RecommendedApp(db.Model): is_listed = db.Column(db.Boolean, nullable=False, default=True) install_count = db.Column(db.Integer, nullable=False, default=0) language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -507,7 +507,7 @@ class InstalledApp(db.Model): position = db.Column(db.Integer, nullable=False, default=0) is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def app(self): @@ -548,8 +548,8 @@ class Conversation(db.Model): read_at = db.Column(db.DateTime) read_account_id = db.Column(StringUUID) dialogue_count: Mapped[int] = mapped_column(default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all") message_annotations = db.relationship( @@ -700,8 +700,10 @@ class Conversation(db.Model): def status_count(self): messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() status_counts = { + WorkflowRunStatus.RUNNING: 0, WorkflowRunStatus.SUCCEEDED: 0, WorkflowRunStatus.FAILED: 0, + WorkflowRunStatus.STOPPED: 0, WorkflowRunStatus.PARTIAL_SUCCESSED: 0, } @@ -789,8 +791,8 @@ class Message(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) workflow_run_id = db.Column(StringUUID) @@ -1115,8 +1117,8 @@ class MessageFeedback(db.Model): from_source = db.Column(db.String(255), nullable=False) from_end_user_id = db.Column(StringUUID) from_account_id = db.Column(StringUUID) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def from_account(self): @@ -1162,9 +1164,7 @@ class MessageFile(db.Model): upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True) created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageAnnotation(db.Model): @@ -1184,8 +1184,8 @@ class MessageAnnotation(db.Model): content = db.Column(db.Text, nullable=False) hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0")) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def account(self): @@ -1214,7 +1214,7 @@ class AppAnnotationHitHistory(db.Model): source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) score = db.Column(Float, nullable=False, server_default=db.text("0")) message_id = db.Column(StringUUID, nullable=False) annotation_question = db.Column(db.Text, nullable=False) @@ -1248,9 +1248,9 @@ class AppAnnotationSetting(db.Model): score_threshold = db.Column(Float, nullable=False, server_default=db.text("0")) collection_binding_id = db.Column(StringUUID, nullable=False) created_user_id = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_user_id = db.Column(StringUUID, nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def created_account(self): @@ -1296,9 +1296,9 @@ class OperationLog(db.Model): account_id = db.Column(StringUUID, nullable=False) action = db.Column(db.String(255), nullable=False) content = db.Column(db.JSON) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_ip = db.Column(db.String(255), nullable=False) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class EndUser(UserMixin, db.Model): @@ -1317,8 +1317,8 @@ class EndUser(UserMixin, db.Model): name = db.Column(db.String(255)) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) session_id = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class Site(db.Model): @@ -1349,9 +1349,9 @@ class Site(db.Model): prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) created_by = db.Column(StringUUID, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = db.Column(StringUUID, nullable=True) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) code = db.Column(db.String(255)) @property @@ -1393,7 +1393,7 @@ class ApiToken(db.Model): type = db.Column(db.String(16), nullable=False) token = db.Column(db.String(255), nullable=False) last_used_at = db.Column(db.DateTime, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @staticmethod def generate_api_key(prefix, n): @@ -1424,9 +1424,7 @@ class UploadFile(db.Model): db.String(255), nullable=False, server_default=db.text("'account'::character varying") ) created_by: Mapped[str] = db.Column(StringUUID, nullable=False) - created_at: Mapped[datetime] = db.Column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) @@ -1483,7 +1481,7 @@ class ApiRequest(db.Model): request = db.Column(db.Text, nullable=True) response = db.Column(db.Text, nullable=True) ip = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class MessageChain(db.Model): @@ -1655,7 +1653,7 @@ class Tag(db.Model): type = db.Column(db.String(16), nullable=False) name = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TagBinding(db.Model): @@ -1671,7 +1669,7 @@ class TagBinding(db.Model): tag_id = db.Column(StringUUID, nullable=True) target_id = db.Column(StringUUID, nullable=True) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TraceAppConfig(db.Model): @@ -1685,8 +1683,10 @@ class TraceAppConfig(db.Model): app_id = db.Column(StringUUID, nullable=False) tracing_provider = db.Column(db.String(255), nullable=True) tracing_config = db.Column(db.JSON, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.now()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now()) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() + ) is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) @property diff --git a/api/models/provider.py b/api/models/provider.py index 65f70b76e9..fdd3e802d7 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,5 +1,7 @@ from enum import Enum +from sqlalchemy import func + from .engine import db from .types import StringUUID @@ -60,8 +62,8 @@ class Provider(db.Model): quota_limit = db.Column(db.BigInteger, nullable=True) quota_used = db.Column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -108,8 +110,8 @@ class ProviderModel(db.Model): model_type = db.Column(db.String(40), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(db.Model): @@ -124,8 +126,8 @@ class TenantDefaultModel(db.Model): provider_name = db.Column(db.String(255), nullable=False) model_name = db.Column(db.String(255), nullable=False) model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(db.Model): @@ -139,8 +141,8 @@ class TenantPreferredModelProvider(db.Model): tenant_id = db.Column(StringUUID, nullable=False) provider_name = db.Column(db.String(255), nullable=False) preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(db.Model): @@ -164,8 +166,8 @@ class ProviderOrder(db.Model): paid_at = db.Column(db.DateTime) pay_failed_at = db.Column(db.DateTime) refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(db.Model): @@ -186,8 +188,8 @@ class ProviderModelSetting(db.Model): model_type = db.Column(db.String(40), nullable=False) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(db.Model): @@ -209,5 +211,5 @@ class LoadBalancingModelConfig(db.Model): name = db.Column(db.String(255), nullable=False) encrypted_config = db.Column(db.Text, nullable=True) enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index 4d98572ef8..114db8e110 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -1,5 +1,6 @@ import json +from sqlalchemy import func from sqlalchemy.dialects.postgresql import JSONB from .engine import db @@ -19,8 +20,8 @@ class DataSourceOauthBinding(db.Model): access_token = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) source_info = db.Column(JSONB, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) @@ -37,8 +38,8 @@ class DataSourceApiKeyAuthBinding(db.Model): category = db.Column(db.String(255), nullable=False) provider = db.Column(db.String(255), nullable=False) credentials = db.Column(db.Text, nullable=True) # JSON - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false")) def to_dict(self): diff --git a/api/models/tools.py b/api/models/tools.py index c390be4625..e90ab669c6 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -2,7 +2,7 @@ import json from typing import Optional import sqlalchemy as sa -from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKey, func from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject @@ -36,8 +36,8 @@ class BuiltinToolProvider(db.Model): provider = db.Column(db.String(40), nullable=False) # credential of the tool provider encrypted_credentials = db.Column(db.Text, nullable=True) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def credentials(self) -> dict: @@ -74,8 +74,8 @@ class PublishedAppTool(db.Model): tool_name = db.Column(db.String(40), nullable=False) # author author = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def description_i18n(self) -> I18nObject: @@ -120,8 +120,8 @@ class ApiToolProvider(db.Model): # custom_disclaimer custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def schema_type(self) -> ApiProviderSchemaType: @@ -198,8 +198,8 @@ class WorkflowToolProvider(db.Model): # privacy policy privacy_policy = db.Column(db.String(255), nullable=True, server_default="") - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def user(self) -> Account | None: @@ -251,8 +251,8 @@ class ToolModelInvoke(db.Model): provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_price = db.Column(db.Numeric(10, 7)) currency = db.Column(db.String(255), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ToolConversationVariables(db.Model): @@ -278,8 +278,8 @@ class ToolConversationVariables(db.Model): # variables pool variables_str = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) - updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def variables(self) -> dict: diff --git a/api/models/web.py b/api/models/web.py index a0f87cf456..028a768519 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -1,3 +1,6 @@ +from sqlalchemy import func +from sqlalchemy.orm import Mapped, mapped_column + from .engine import db from .model import Message from .types import StringUUID @@ -15,7 +18,7 @@ class SavedMessage(db.Model): message_id = db.Column(StringUUID, nullable=False) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def message(self): @@ -31,7 +34,7 @@ class PinnedConversation(db.Model): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - conversation_id = db.Column(StringUUID, nullable=False) + conversation_id: Mapped[str] = mapped_column(StringUUID) created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) diff --git a/api/models/workflow.py b/api/models/workflow.py index 51a6fbc8c8..d5be949bf4 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -103,12 +103,13 @@ class Workflow(db.Model): graph: Mapped[str] = mapped_column(sa.Text) _features: Mapped[str] = mapped_column("features", sa.TEXT) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_at: Mapped[datetime] = mapped_column( - db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") - ) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() + db.DateTime, + nullable=False, + default=datetime.now(UTC).replace(tzinfo=None), + server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" @@ -406,7 +407,7 @@ class WorkflowRun(db.Model): total_steps = db.Column(db.Integer, server_default=db.text("0")) created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) finished_at = db.Column(db.DateTime) exceptions_count = db.Column(db.Integer, server_default=db.text("0")) @@ -636,7 +637,7 @@ class WorkflowNodeExecution(db.Model): error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) execution_metadata = db.Column(db.Text) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) finished_at = db.Column(db.DateTime) @@ -754,7 +755,7 @@ class WorkflowAppLog(db.Model): created_from = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False) created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -780,7 +781,7 @@ class ConversationVariable(db.Model): conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True) app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True) data = db.Column(db.Text, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)")) + created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=func.current_timestamp()) updated_at = db.Column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8180c3b400..0478903fa4 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -340,7 +340,10 @@ class AppDslService: ) -> App: """Create a new app or update an existing one.""" app_data = data.get("app", {}) - app_mode = AppMode(app_data["mode"]) + app_mode = app_data.get("mode") + if not app_mode: + raise ValueError("loss app mode") + app_mode = AppMode(app_mode) # Set icon type icon_type_value = icon_type or app_data.get("icon_type") diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 8642972710..456dc3ebeb 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,8 +1,9 @@ -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import UTC, datetime from typing import Optional, Union -from sqlalchemy import asc, desc, or_ +from sqlalchemy import asc, desc, func, or_, select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator @@ -18,19 +19,21 @@ class ConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], limit: int, invoke_from: InvokeFrom, - include_ids: Optional[list] = None, - exclude_ids: Optional[list] = None, + include_ids: Optional[Sequence[str]] = None, + exclude_ids: Optional[Sequence[str]] = None, sort_by: str = "-updated_at", ) -> InfiniteScrollPagination: if not user: return InfiniteScrollPagination(data=[], limit=limit, has_more=False) - base_query = db.session.query(Conversation).filter( + stmt = select(Conversation).where( Conversation.is_deleted == False, Conversation.app_id == app_model.id, Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"), @@ -38,37 +41,40 @@ class ConversationService: Conversation.from_account_id == (user.id if isinstance(user, Account) else None), or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value), ) - if include_ids is not None: - base_query = base_query.filter(Conversation.id.in_(include_ids)) - + stmt = stmt.where(Conversation.id.in_(include_ids)) if exclude_ids is not None: - base_query = base_query.filter(~Conversation.id.in_(exclude_ids)) + stmt = stmt.where(~Conversation.id.in_(exclude_ids)) # define sort fields and directions sort_field, sort_direction = cls._get_sort_params(sort_by) if last_id: - last_conversation = base_query.filter(Conversation.id == last_id).first() + last_conversation = session.scalar(stmt.where(Conversation.id == last_id)) if not last_conversation: raise LastConversationNotExistsError() # build filters based on sorting - filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation) - base_query = base_query.filter(filter_condition) - - base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field))) - - conversations = base_query.limit(limit).all() + filter_condition = cls._build_filter_condition( + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=last_conversation, + ) + stmt = stmt.where(filter_condition) + query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit) + conversations = session.scalars(query_stmt).all() has_more = False if len(conversations) == limit: current_page_last_conversation = conversations[-1] rest_filter_condition = cls._build_filter_condition( - sort_field, sort_direction, current_page_last_conversation, is_next_page=True + sort_field=sort_field, + sort_direction=sort_direction, + reference_conversation=current_page_last_conversation, ) - rest_count = base_query.filter(rest_filter_condition).count() - + count_stmt = stmt.where(rest_filter_condition) + count_stmt = select(func.count()).select_from(count_stmt.subquery()) + rest_count = session.scalar(count_stmt) or 0 if rest_count > 0: has_more = True @@ -81,11 +87,9 @@ class ConversationService: return sort_by, asc @classmethod - def _build_filter_condition( - cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False - ): + def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation): field_value = getattr(reference_conversation, sort_field) - if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): + if sort_direction == desc: return getattr(Conversation, sort_field) < field_value else: return getattr(Conversation, sort_field) > field_value diff --git a/api/services/message_service.py b/api/services/message_service.py index f432a77c80..be2922f4c5 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -151,7 +151,12 @@ class MessageService: @classmethod def create_feedback( - cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str] + cls, + app_model: App, + message_id: str, + user: Optional[Union[Account, EndUser]], + rating: Optional[str], + content: Optional[str], ) -> MessageFeedback: if not user: raise ValueError("user cannot be None") @@ -164,6 +169,7 @@ class MessageService: db.session.delete(feedback) elif rating and feedback: feedback.rating = rating + feedback.content = content elif not rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") else: @@ -172,6 +178,7 @@ class MessageService: conversation_id=message.conversation_id, message_id=message.id, rating=rating, + content=content, from_source=("user" if isinstance(user, EndUser) else "admin"), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index e2e49d017e..fada881fde 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,6 +2,9 @@ import json import logging from pathlib import Path +from sqlalchemy import select +from sqlalchemy.orm import Session + from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder @@ -32,7 +35,7 @@ class BuiltinToolManageService: tenant_id=tenant_id, provider_controller=provider_controller ) # check if user has added the provider - builtin_provider: BuiltinToolProvider = ( + builtin_provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, @@ -71,19 +74,18 @@ class BuiltinToolManageService: return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()]) @staticmethod - def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict): + def update_builtin_tool_provider( + session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + ): """ update builtin tool provider """ # get if the provider exists - provider: BuiltinToolProvider = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider_name, - ) - .first() + stmt = select(BuiltinToolProvider).where( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider_name, ) + provider = session.scalar(stmt) try: # get provider @@ -115,13 +117,10 @@ class BuiltinToolManageService: encrypted_credentials=json.dumps(credentials), ) - db.session.add(provider) - db.session.commit() + session.add(provider) else: provider.encrypted_credentials = json.dumps(credentials) - db.session.add(provider) - db.session.commit() # delete cache tool_configuration.delete_tool_credentials_cache() @@ -129,15 +128,15 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): """ get builtin tool provider credentials """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == provider, + BuiltinToolProvider.provider == provider_name, ) .first() ) @@ -156,7 +155,7 @@ class BuiltinToolManageService: """ delete tool provider """ - provider: BuiltinToolProvider = ( + provider = ( db.session.query(BuiltinToolProvider) .filter( BuiltinToolProvider.tenant_id == tenant_id, diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index d7ccc964cb..508fe20970 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -1,5 +1,8 @@ from typing import Optional, Union +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -13,6 +16,8 @@ class WebConversationService: @classmethod def pagination_by_last_id( cls, + *, + session: Session, app_model: App, user: Optional[Union[Account, EndUser]], last_id: Optional[str], @@ -23,24 +28,25 @@ class WebConversationService: ) -> InfiniteScrollPagination: include_ids = None exclude_ids = None - if pinned is not None: - pinned_conversations = ( - db.session.query(PinnedConversation) - .filter( + if pinned is not None and user: + stmt = ( + select(PinnedConversation.conversation_id) + .where( PinnedConversation.app_id == app_model.id, PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"), PinnedConversation.created_by == user.id, ) .order_by(PinnedConversation.created_at.desc()) - .all() ) - pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations] + pinned_conversation_ids = session.scalars(stmt).all() + if pinned: include_ids = pinned_conversation_ids else: exclude_ids = pinned_conversation_ids return ConversationService.pagination_by_last_id( + session=session, app_model=app_model, user=user, last_id=last_id, diff --git a/docker/.env.example b/docker/.env.example index e8ec246ae2..43e67a8db4 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -923,6 +923,3 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false # Maximum number of submitted thread count in a ThreadPool for parallel node execution MAX_SUBMIT_COUNT=100 -# Proxy -HTTP_PROXY= -HTTPS_PROXY= diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index d0738d9305..99bc14c717 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -386,8 +386,6 @@ x-shared-env: &shared-api-worker-env CSP_WHITELIST: ${CSP_WHITELIST:-} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100} - HTTP_PROXY: ${HTTP_PROXY:-} - HTTPS_PROXY: ${HTTPS_PROXY:-} services: # API service diff --git a/web/app/components/base/chat/chat/answer/workflow-process.tsx b/web/app/components/base/chat/chat/answer/workflow-process.tsx index 4a09e27d98..bb9abdb6fc 100644 --- a/web/app/components/base/chat/chat/answer/workflow-process.tsx +++ b/web/app/components/base/chat/chat/answer/workflow-process.tsx @@ -64,6 +64,12 @@ const WorkflowProcessItem = ({ setShowMessageLogModal(true) }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + const showRetryDetail = useCallback(() => { + setCurrentLogItem(item) + setCurrentLogModalActiveTab('TRACING') + setShowMessageLogModal(true) + }, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal]) + return (
diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index bf8efdb65a..044fc27858 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -28,6 +28,7 @@ export type InputProps = { destructive?: boolean wrapperClassName?: string styleCss?: CSSProperties + unit?: string } & React.InputHTMLAttributes