mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into fix/remove-the-retry-index-field
commit 4d35df9210
@@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
    if language not in languages:
        language = "en-US"

    name = name.strip()
    # Validates name encoding for non-Latin characters.
    name = name.strip().encode("utf-8").decode("utf-8") if name else None

    # generate random password
    new_password = secrets.token_urlsafe(16)
@@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException
class FilenameNotExistsError(HTTPException):
    code = 400
    description = "The specified filename does not exist."


class RemoteFileUploadError(HTTPException):
    code = 400
    description = "Error uploading remote file."
@@ -1,12 +1,14 @@
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
        pinned = True if args["pinned"] == "true" else False

        try:
            return WebConversationService.pagination_by_last_id(
                app_model=app_model,
                user=current_user,
                last_id=args["last_id"],
                limit=args["limit"],
                invoke_from=InvokeFrom.EXPLORE,
                pinned=pinned,
            )
            with Session(db.engine) as session:
                return WebConversationService.pagination_by_last_id(
                    session=session,
                    app_model=app_model,
                    user=current_user,
                    last_id=args["last_id"],
                    limit=args["limit"],
                    invoke_from=InvokeFrom.EXPLORE,
                    pinned=pinned,
                )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
@@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource):
        args = parser.parse_args()

        try:
            MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
            MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")
@@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse

import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
@@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource):
        url = args["url"]

        resp = ssrf_proxy.head(url=url)
        if resp.status_code != httpx.codes.OK:
            resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
        resp.raise_for_status()
        try:
            resp = ssrf_proxy.head(url=url)
            if resp.status_code != httpx.codes.OK:
                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
            if resp.status_code != httpx.codes.OK:
                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
        except httpx.RequestError as e:
            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

        file_info = helpers.guess_file_info_from_response(resp)
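A minimal standalone sketch of the HEAD-then-GET probe used in the hunk above, written against plain httpx rather than the project's ssrf_proxy wrapper; the wrapper's signature is assumed here to mirror httpx, and the function name is illustrative.

import httpx

def probe_remote_file(url: str) -> httpx.Response:
    # Try a cheap HEAD first; fall back to a GET when the server rejects HEAD.
    resp = httpx.head(url, timeout=3)
    if resp.status_code != httpx.codes.OK:
        resp = httpx.get(url, timeout=3, follow_redirects=True)
    if resp.status_code != httpx.codes.OK:
        raise RuntimeError(f"Failed to fetch file from {url}: {resp.text}")
    return resp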
@@ -3,12 +3,14 @@ import io
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden

from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
@@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
        args = parser.parse_args()

        return BuiltinToolManageService.update_builtin_tool_provider(
            user_id,
            tenant_id,
            provider,
            args["credentials"],
        )
        with Session(db.engine) as session:
            result = BuiltinToolManageService.update_builtin_tool_provider(
                session=session,
                user_id=user_id,
                tenant_id=tenant_id,
                provider_name=provider,
                credentials=args["credentials"],
            )
            session.commit()
            return result


class ToolBuiltinProviderGetCredentialsApi(Resource):
@@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
    @login_required
    @account_initialization_required
    def get(self, provider):
        user_id = current_user.id
        tenant_id = current_user.current_tenant_id

        return BuiltinToolManageService.get_builtin_tool_provider_credentials(
            user_id,
            tenant_id,
            provider,
            tenant_id=tenant_id,
            provider_name=provider,
        )
@@ -1,5 +1,6 @@
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

import services
@@ -7,6 +8,7 @@ from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
    conversation_delete_fields,
    conversation_infinite_scroll_pagination_fields,
@@ -39,14 +41,16 @@ class ConversationApi(Resource):
        args = parser.parse_args()

        try:
            return ConversationService.pagination_by_last_id(
                app_model=app_model,
                user=end_user,
                last_id=args["last_id"],
                limit=args["limit"],
                invoke_from=InvokeFrom.SERVICE_API,
                sort_by=args["sort_by"],
            )
            with Session(db.engine) as session:
                return ConversationService.pagination_by_last_id(
                    session=session,
                    app_model=app_model,
                    user=end_user,
                    last_id=args["last_id"],
                    limit=args["limit"],
                    invoke_from=InvokeFrom.SERVICE_API,
                    sort_by=args["sort_by"],
                )
        except services.errors.conversation.LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
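For reference, a hedged sketch of how a caller might drive the new session-scoped pagination_by_last_id shown above; the surrounding variables (app_model, end_user), the import path for ConversationService, and the sort_by value are assumptions, not part of this diff.

from sqlalchemy.orm import Session

from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from services.conversation_service import ConversationService  # assumed import path

with Session(db.engine) as session:
    page = ConversationService.pagination_by_last_id(
        session=session,
        app_model=app_model,      # assumed to be resolved by the calling view
        user=end_user,            # assumed current end user
        last_id=None,
        limit=20,
        invoke_from=InvokeFrom.SERVICE_API,
        sort_by="-updated_at",    # assumed default sort key
    )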
@@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource):
        parser = reqparse.RequestParser()
        parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
        parser.add_argument("content", type=str, location="json")
        args = parser.parse_args()

        try:
            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")
@@ -1,11 +1,13 @@
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
        pinned = True if args["pinned"] == "true" else False

        try:
            return WebConversationService.pagination_by_last_id(
                app_model=app_model,
                user=end_user,
                last_id=args["last_id"],
                limit=args["limit"],
                invoke_from=InvokeFrom.WEB_APP,
                pinned=pinned,
                sort_by=args["sort_by"],
            )
            with Session(db.engine) as session:
                return WebConversationService.pagination_by_last_id(
                    session=session,
                    app_model=app_model,
                    user=end_user,
                    last_id=args["last_id"],
                    limit=args["limit"],
                    invoke_from=InvokeFrom.WEB_APP,
                    pinned=pinned,
                    sort_by=args["sort_by"],
                )
        except LastConversationNotExistsError:
            raise NotFound("Last Conversation Not Exists.")
@@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource):
        args = parser.parse_args()

        try:
            MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
            MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
        except services.errors.message.MessageNotExistsError:
            raise NotFound("Message Not Exists.")
@@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse

import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
@@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource):
        url = args["url"]

        resp = ssrf_proxy.head(url=url)
        if resp.status_code != httpx.codes.OK:
            resp = ssrf_proxy.get(url=url, timeout=3)
        resp.raise_for_status()
        try:
            resp = ssrf_proxy.head(url=url)
            if resp.status_code != httpx.codes.OK:
                resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
            if resp.status_code != httpx.codes.OK:
                raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
        except httpx.RequestError as e:
            raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")

        file_info = helpers.guess_file_info_from_response(resp)
@@ -4,14 +4,17 @@ import logging
import queue
import re
import threading
from collections.abc import Iterable

from core.app.entities.queue_entities import (
    MessageQueueMessage,
    QueueAgentMessageEvent,
    QueueLLMChunkEvent,
    QueueNodeSucceededEvent,
    QueueTextChunkEvent,
    WorkflowQueueMessage,
)
from core.model_manager import ModelManager
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -21,7 +24,7 @@ class AudioTrunk:
        self.status = status


def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
    if not text_content or text_content.isspace():
        return
    return model_instance.invoke_tts(
@@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
    )


def _process_future(future_queue, audio_queue):
def _process_future(
    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
    audio_queue: queue.Queue[AudioTrunk],
):
    while True:
        try:
            future = future_queue.get()
            if future is None:
                break
            for audio in future.result():
            invoke_result = future.result()
            if not invoke_result:
                continue
            for audio in invoke_result:
                audio_base64 = base64.b64encode(bytes(audio))
                audio_queue.put(AudioTrunk("responding", audio=audio_base64))
        except Exception as e:
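As a side note, a minimal standalone sketch of the sentinel-terminated future queue that _process_future consumes above; the lambda payload is a stand-in for the real TTS invocation, not part of this diff.

import concurrent.futures
import queue

future_queue: queue.Queue = queue.Queue()
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
    # Producer side: submit work, then push None so the consumer loop can break.
    future_queue.put(pool.submit(lambda: [b"audio-chunk"]))
    future_queue.put(None)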
@@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ""
        self._audio_queue = queue.Queue()
        self._msg_queue = queue.Queue()
        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
        self.match = re.compile(r"[。.!?]")
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(
@@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
        self._runtime_thread = threading.Thread(target=self._runtime).start()
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)

    def publish(self, message):
        try:
            self._msg_queue.put(message)
        except Exception as e:
            self.logger.warning(e)
    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
        self._msg_queue.put(message)

    def _runtime(self):
        future_queue = queue.Queue()
        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
        threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
        while True:
            try:
@@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
                break
        future_queue.put(None)

    def check_and_get_audio(self) -> AudioTrunk | None:
    def check_and_get_audio(self):
        try:
            if self._last_audio_event and self._last_audio_event.status == "finish":
                if self.executor:
@@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
            stream_response=stream_response,
        )

    def _listen_audio_msg(self, publisher, task_id: str):
    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        if not publisher:
            return None
        audio_msg: AudioTrunk = publisher.check_and_get_audio()
        if audio_msg and audio_msg.status != "finish":
        audio_msg = publisher.check_and_get_audio()
        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None
@@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
@@ -512,7 +512,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                # only publish tts message at text chunk streaming
                if tts_publisher:
                    tts_publisher.publish(message=queue_message)
                    tts_publisher.publish(queue_message)

                self._task_state.answer += delta_text
                yield self._message_to_stream_response(
@@ -1,7 +1,6 @@
import queue
import time
from abc import abstractmethod
from collections.abc import Generator
from enum import Enum
from typing import Any
@@ -11,9 +10,11 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
    AppQueueEvent,
    MessageQueueMessage,
    QueueErrorEvent,
    QueuePingEvent,
    QueueStopEvent,
    WorkflowQueueMessage,
)
from extensions.ext_redis import redis_client
@@ -37,11 +38,11 @@ class AppQueueManager:
            AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
        )

        q = queue.Queue()
        q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()

        self._q = q

    def listen(self) -> Generator:
    def listen(self):
        """
        Listen to queue
        :return:
@@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)

    def _listen_audio_msg(self, publisher, task_id: str):
    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        if not publisher:
            return None
        audio_msg: AudioTrunk = publisher.check_and_get_audio()
        if audio_msg and audio_msg.status != "finish":
        audio_msg = publisher.check_and_get_audio()
        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None
@@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
        for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
            while True:
                audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
                audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
                if audio_response:
                    yield audio_response
                else:
@@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                # only publish tts message at text chunk streaming
                if tts_publisher:
                    tts_publisher.publish(message=queue_message)
                    tts_publisher.publish(queue_message)

                self._task_state.answer += delta_text
                yield self._text_chunk_to_stream_response(
@@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
            stream_response=stream_response,
        )

    def _listen_audio_msg(self, publisher, task_id: str):
    def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
        if publisher is None:
            return None
        audio_msg: AudioTrunk = publisher.check_and_get_audio()
        if audio_msg and audio_msg.status != "finish":
        audio_msg = publisher.check_and_get_audio()
        if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
            # audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
            return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
        return None
@@ -508,6 +508,12 @@ class WorkflowCycleManage:
        :param workflow_run: workflow run
        :return:
        """
        # Attach WorkflowRun to an active session so "created_by_role" can be accessed.
        workflow_run = db.session.merge(workflow_run)

        # Refresh to ensure any expired attributes are fully loaded
        db.session.refresh(workflow_run)

        created_by = None
        if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
            created_by_account = workflow_run.created_by_account
@@ -1,7 +1,7 @@
from typing import Optional


class LLMError(Exception):
class LLMError(ValueError):
    """Base class for all LLM exceptions."""

    description: Optional[str] = None
@@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
    description = "Bad Request"


class ProviderTokenNotInitError(Exception):
class ProviderTokenNotInitError(ValueError):
    """
    Custom exception raised when the provider token is not initialized.
    """
@@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception):
        self.description = args[0] if args else self.description


class QuotaExceededError(Exception):
class QuotaExceededError(ValueError):
    """
    Custom exception raised when the quota for a provider has been exceeded.
    """
@@ -35,7 +35,7 @@ class QuotaExceededError(Exception):
    description = "Quota Exceeded"


class AppInvokeQuotaExceededError(Exception):
class AppInvokeQuotaExceededError(ValueError):
    """
    Custom exception raised when the quota for an app has been exceeded.
    """
@@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception):
    description = "App Invoke Quota Exceeded"


class ModelCurrentlyNotSupportError(Exception):
class ModelCurrentlyNotSupportError(ValueError):
    """
    Custom exception raised when the model not support
    """
@@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception):
    description = "Model Currently Not Support"


class InvokeRateLimitError(Exception):
class InvokeRateLimitError(ValueError):
    """Raised when the Invoke returns rate limit error."""

    description = "Rate Limit Error"
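A brief sketch of what the Exception-to-ValueError rebasing above buys callers: a single ValueError handler now also catches these domain errors. QuotaExceededError is taken from the hunk; the handler itself is illustrative.

try:
    raise QuotaExceededError("provider quota used up")
except ValueError as e:
    # After this change, generic ValueError handling covers the LLM error family too.
    print(f"handled: {e}")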
@@ -118,7 +118,7 @@ class CodeExecutor:
        return response.data.stdout or ""

    @classmethod
    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
    def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
        """
        Execute code
        :param language: code language
@@ -25,7 +25,7 @@ class TemplateTransformer(ABC):
        return runner_script, preload_script

    @classmethod
    def extract_result_str_from_response(cls, response: str) -> str:
    def extract_result_str_from_response(cls, response: str):
        result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
        if not result:
            raise ValueError("Failed to parse result")
@@ -33,13 +33,21 @@ class TemplateTransformer(ABC):
        return result

    @classmethod
    def transform_response(cls, response: str) -> dict:
    def transform_response(cls, response: str) -> Mapping[str, Any]:
        """
        Transform response to dict
        :param response: response
        :return:
        """
        return json.loads(cls.extract_result_str_from_response(response))
        try:
            result = json.loads(cls.extract_result_str_from_response(response))
        except json.JSONDecodeError:
            raise ValueError("failed to parse response")
        if not isinstance(result, dict):
            raise ValueError("result must be a dict")
        if not all(isinstance(k, str) for k in result):
            raise ValueError("result keys must be strings")
        return result

    @classmethod
    @abstractmethod
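A standalone sketch of the stricter parsing that transform_response now performs, shown without the class plumbing; the function name is illustrative.

import json
from typing import Any

def parse_result(payload: str) -> dict[str, Any]:
    # Mirror of the validation added above: bad JSON and non-dict payloads are rejected.
    try:
        result = json.loads(payload)
    except json.JSONDecodeError:
        raise ValueError("failed to parse response")
    if not isinstance(result, dict):
        raise ValueError("result must be a dict")
    if not all(isinstance(k, str) for k in result):
        raise ValueError("result keys must be strings")
    return result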
@@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]


class MaxRetriesExceededError(Exception):
class MaxRetriesExceededError(ValueError):
    """Raised when the maximum number of retries is exceeded."""

    pass
@@ -1,2 +1,2 @@
class OutputParserError(Exception):
class OutputParserError(ValueError):
    pass
@@ -1,7 +1,7 @@
from typing import Optional


class InvokeError(Exception):
class InvokeError(ValueError):
    """Base class for all LLM exceptions."""

    description: Optional[str] = None
@@ -1,4 +1,4 @@
class CredentialsValidateFailedError(Exception):
class CredentialsValidateFailedError(ValueError):
    """
    Credentials validate failed error
    """
@@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                        "source": {
                            "type": "base64",
                            "media_type": message_content.mime_type,
                            "data": message_content.data,
                            "data": message_content.base64_data,
                        },
                    }
                    sub_messages.append(sub_message_dict)
@@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
@@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """

        try:
            ping_message = SystemPromptMessage(content="ping")
            ping_message = UserPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})

        except Exception as ex:
@@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            config_kwargs["stop_sequences"] = stop

        genai.configure(api_key=credentials["google_api_key"])
        google_model = genai.GenerativeModel(model_name=model)

        history = []
        system_instruction = None

        for msg in prompt_messages:  # makes message roles strictly alternating
            content = self._format_message_to_glm_content(msg)
            if history and history[-1]["role"] == content["role"]:
                history[-1]["parts"].extend(content["parts"])
            elif content["role"] == "system":
                system_instruction = content["parts"][0]
            else:
                history.append(content)

        if not history:
            raise InvokeError("The user prompt message is required. You only add a system prompt message.")

        google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
        response = google_model.generate_content(
            contents=history,
            generation_config=genai.types.GenerationConfig(**config_kwargs),
@@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            )
            return glm_content
        elif isinstance(message, SystemPromptMessage):
            return {"role": "user", "parts": [to_part(message.content)]}
            if isinstance(message.content, list):
                text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
                message.content = "".join(c.data for c in text_contents)
            return {"role": "system", "parts": [to_part(message.content)]}
        elif isinstance(message, ToolPromptMessage):
            return {
                "role": "function",
@@ -355,7 +355,13 @@ class TraceTask:
    def conversation_trace(self, **kwargs):
        return kwargs

    def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
    def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
        if not workflow_run:
            raise ValueError("Workflow run not found")

        db.session.merge(workflow_run)
        db.session.refresh(workflow_run)

        workflow_id = workflow_run.workflow_id
        tenant_id = workflow_run.tenant_id
        workflow_run_id = workflow_run.id
@@ -83,11 +83,15 @@ class DataPostProcessor:
        if reranking_model:
            try:
                model_manager = ModelManager()
                reranking_provider_name = reranking_model.get("reranking_provider_name")
                reranking_model_name = reranking_model.get("reranking_model_name")
                if not reranking_provider_name or not reranking_model_name:
                    return None
                rerank_model_instance = model_manager.get_model_instance(
                    tenant_id=tenant_id,
                    provider=reranking_model["reranking_provider_name"],
                    provider=reranking_provider_name,
                    model_type=ModelType.RERANK,
                    model=reranking_model["reranking_model_name"],
                    model=reranking_model_name,
                )
                return rerank_model_instance
            except InvokeAuthorizationError:
@@ -103,7 +103,7 @@ class RetrievalService:
        if exceptions:
            exception_message = ";\n".join(exceptions)
            raise Exception(exception_message)
            raise ValueError(exception_message)

        if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
            data_post_processor = DataPostProcessor(
@@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController):
        user_input_form_list = app_model_config.user_input_form_list
        for input_form in user_input_form_list:
            # get type
            form_type = input_form.keys()[0]
            form_type = list(input_form.keys())[0]
            default = input_form[form_type]["default"]
            required = input_form[form_type]["required"]
            label = input_form[form_type]["label"]
@@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union

import boto3

from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool


class BedrockRetrieveTool(BuiltinTool):
    bedrock_client: Any = None
    knowledge_base_id: str = None
    topk: int = None

    def _bedrock_retrieve(
        self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
    ):
        try:
            retrieval_query = {"text": query_input}

            retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}

            # If a metadata filter is provided, add it to the retrieval configuration
            if metadata_filter:
                retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter

            response = self.bedrock_client.retrieve(
                knowledgeBaseId=knowledge_base_id,
                retrievalQuery=retrieval_query,
                retrievalConfiguration=retrieval_configuration,
            )

            results = []
            for result in response.get("retrievalResults", []):
                results.append(
                    {
                        "content": result.get("content", {}).get("text", ""),
                        "score": result.get("score", 0.0),
                        "metadata": result.get("metadata", {}),
                    }
                )

            return results
        except Exception as e:
            raise Exception(f"Error retrieving from knowledge base: {str(e)}")

    def _invoke(
        self,
        user_id: str,
        tool_parameters: dict[str, Any],
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        invoke tools
        """
        line = 0
        try:
            if not self.bedrock_client:
                aws_region = tool_parameters.get("aws_region")
                if aws_region:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
                else:
                    self.bedrock_client = boto3.client("bedrock-agent-runtime")

            line = 1
            if not self.knowledge_base_id:
                self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
                if not self.knowledge_base_id:
                    return self.create_text_message("Please provide knowledge_base_id")

            line = 2
            if not self.topk:
                self.topk = tool_parameters.get("topk", 5)

            line = 3
            query = tool_parameters.get("query", "")
            if not query:
                return self.create_text_message("Please input query")

            # Read the metadata filter conditions, if any
            metadata_filter_str = tool_parameters.get("metadata_filter")
            metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None

            line = 4
            retrieved_docs = self._bedrock_retrieve(
                query_input=query,
                knowledge_base_id=self.knowledge_base_id,
                num_results=self.topk,
                metadata_filter=metadata_filter,  # pass the metadata filter to the retrieve call
            )

            line = 5
            # Sort results by score in descending order
            sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)

            line = 6
            return [self.create_json_message(res) for res in sorted_docs]

        except Exception as e:
            return self.create_text_message(f"Exception {str(e)}, line : {line}")

    def validate_parameters(self, parameters: dict[str, Any]) -> None:
        """
        Validate the parameters
        """
        if not parameters.get("knowledge_base_id"):
            raise ValueError("knowledge_base_id is required")

        if not parameters.get("query"):
            raise ValueError("query is required")

        # Optional: verify that the metadata filter, if provided, is a valid JSON object
        metadata_filter_str = parameters.get("metadata_filter")
        if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
            raise ValueError("metadata_filter must be a valid JSON object")
@@ -0,0 +1,87 @@
identity:
  name: bedrock_retrieve
  author: AWS
  label:
    en_US: Bedrock Retrieve
    zh_Hans: Bedrock检索
    pt_BR: Bedrock Retrieve
  icon: icon.svg

description:
  human:
    en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
    zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
    pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
  llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool

parameters:
  - name: knowledge_base_id
    type: string
    required: true
    label:
      en_US: Bedrock Knowledge Base ID
      zh_Hans: Bedrock知识库ID
      pt_BR: Bedrock Knowledge Base ID
    human_description:
      en_US: ID of the Bedrock Knowledge Base to retrieve from
      zh_Hans: 用于检索的Bedrock知识库ID
      pt_BR: ID of the Bedrock Knowledge Base to retrieve from
    llm_description: ID of the Bedrock Knowledge Base to retrieve from
    form: form

  - name: query
    type: string
    required: true
    label:
      en_US: Query string
      zh_Hans: 查询语句
      pt_BR: Query string
    human_description:
      en_US: The search query to retrieve relevant information
      zh_Hans: 用于检索相关信息的查询语句
      pt_BR: The search query to retrieve relevant information
    llm_description: The search query to retrieve relevant information
    form: llm

  - name: topk
    type: number
    required: false
    form: form
    label:
      en_US: Limit for results count
      zh_Hans: 返回结果数量限制
      pt_BR: Limit for results count
    human_description:
      en_US: Maximum number of results to return
      zh_Hans: 最大返回结果数量
      pt_BR: Maximum number of results to return
    min: 1
    max: 10
    default: 5

  - name: aws_region
    type: string
    required: false
    label:
      en_US: AWS Region
      zh_Hans: AWS 区域
      pt_BR: AWS Region
    human_description:
      en_US: AWS region where the Bedrock Knowledge Base is located
      zh_Hans: Bedrock知识库所在的AWS区域
      pt_BR: AWS region where the Bedrock Knowledge Base is located
    llm_description: AWS region where the Bedrock Knowledge Base is located
    form: form

  - name: metadata_filter
    type: string
    required: false
    label:
      en_US: Metadata Filter
      zh_Hans: 元数据过滤器
      pt_BR: Metadata Filter
    human_description:
      en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
      zh_Hans: '元数据的JSON格式过滤条件(例如,{{"greaterThan": {"key: "aaa", "value": 10}})'
      pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
    form: form
@@ -0,0 +1,357 @@
import base64
import json
import logging
import re
from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse

import boto3

from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class NovaCanvasTool(BuiltinTool):
    def _invoke(
        self, user_id: str, tool_parameters: dict[str, Any]
    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
        """
        Invoke AWS Bedrock Nova Canvas model for image generation
        """
        # Get common parameters
        prompt = tool_parameters.get("prompt", "")
        image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
        if not prompt:
            return self.create_text_message("Please provide a text prompt for image generation.")
        if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
            return self.create_text_message("Please provide an valid S3 URI for image output.")

        task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
        aws_region = tool_parameters.get("aws_region", "us-east-1")

        # Get common image generation config parameters
        width = tool_parameters.get("width", 1024)
        height = tool_parameters.get("height", 1024)
        cfg_scale = tool_parameters.get("cfg_scale", 8.0)
        negative_prompt = tool_parameters.get("negative_prompt", "")
        seed = tool_parameters.get("seed", 0)
        quality = tool_parameters.get("quality", "standard")

        # Handle S3 image if provided
        image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
        if task_type != "TEXT_IMAGE":
            if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
                return self.create_text_message("Please provide a valid S3 URI for image to image generation.")

            # Parse S3 URI
            parsed_uri = urlparse(image_input_s3uri)
            bucket = parsed_uri.netloc
            key = parsed_uri.path.lstrip("/")

            # Initialize S3 client and download image
            s3_client = boto3.client("s3")
            response = s3_client.get_object(Bucket=bucket, Key=key)
            image_data = response["Body"].read()

            # Base64 encode the image
            input_image = base64.b64encode(image_data).decode("utf-8")

        try:
            # Initialize Bedrock client
            bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)

            # Base image generation config
            image_generation_config = {
                "width": width,
                "height": height,
                "cfgScale": cfg_scale,
                "seed": seed,
                "numberOfImages": 1,
                "quality": quality,
            }

            # Prepare request body based on task type
            body = {"imageGenerationConfig": image_generation_config}

            if task_type == "TEXT_IMAGE":
                body["taskType"] = "TEXT_IMAGE"
                body["textToImageParams"] = {"text": prompt}
                if negative_prompt:
                    body["textToImageParams"]["negativeText"] = negative_prompt

            elif task_type == "COLOR_GUIDED_GENERATION":
                colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
                if not self._validate_color_string(colors):
                    return self.create_text_message("Please provide valid colors in hexadecimal format.")

                body["taskType"] = "COLOR_GUIDED_GENERATION"
                body["colorGuidedGenerationParams"] = {
                    "colors": colors.split("-"),
                    "referenceImage": input_image,
                    "text": prompt,
                }
                if negative_prompt:
                    body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt

            elif task_type == "IMAGE_VARIATION":
                similarity_strength = tool_parameters.get("similarity_strength", 0.5)

                body["taskType"] = "IMAGE_VARIATION"
                body["imageVariationParams"] = {
                    "images": [input_image],
                    "similarityStrength": similarity_strength,
                    "text": prompt,
                }
                if negative_prompt:
                    body["imageVariationParams"]["negativeText"] = negative_prompt

            elif task_type == "INPAINTING":
                mask_prompt = tool_parameters.get("mask_prompt")
                if not mask_prompt:
                    return self.create_text_message("Please provide a mask prompt for image inpainting.")

                body["taskType"] = "INPAINTING"
                body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
                if negative_prompt:
                    body["inPaintingParams"]["negativeText"] = negative_prompt

            elif task_type == "OUTPAINTING":
                mask_prompt = tool_parameters.get("mask_prompt")
                if not mask_prompt:
                    return self.create_text_message("Please provide a mask prompt for image outpainting.")
                outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")

                body["taskType"] = "OUTPAINTING"
                body["outPaintingParams"] = {
                    "image": input_image,
                    "maskPrompt": mask_prompt,
                    "outPaintingMode": outpainting_mode,
                    "text": prompt,
                }
                if negative_prompt:
                    body["outPaintingParams"]["negativeText"] = negative_prompt

            elif task_type == "BACKGROUND_REMOVAL":
                body["taskType"] = "BACKGROUND_REMOVAL"
                body["backgroundRemovalParams"] = {"image": input_image}

            else:
                return self.create_text_message(f"Unsupported task type: {task_type}")

            # Call Nova Canvas model
            response = bedrock.invoke_model(
                body=json.dumps(body),
                modelId="amazon.nova-canvas-v1:0",
                accept="application/json",
                contentType="application/json",
            )

            # Process response
            response_body = json.loads(response.get("body").read())
            if response_body.get("error"):
                raise Exception(f"Error in model response: {response_body.get('error')}")
            base64_image = response_body.get("images")[0]

            # Upload to S3 if image_output_s3uri is provided
            try:
                # Parse S3 URI for output
                parsed_uri = urlparse(image_output_s3uri)
                output_bucket = parsed_uri.netloc
                output_base_path = parsed_uri.path.lstrip("/")
                # Generate filename with timestamp
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                output_key = f"{output_base_path}/canvas-output-{timestamp}.png"

                # Initialize S3 client if not already done
                s3_client = boto3.client("s3", region_name=aws_region)

                # Decode base64 image and upload to S3
                image_data = base64.b64decode(base64_image)
                s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
                logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
            except Exception as e:
                logger.exception("Failed to upload image to S3")
            # Return image
            return [
                self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
                self.create_blob_message(
                    blob=base64.b64decode(base64_image),
                    meta={"mime_type": "image/png"},
                    save_as=self.VariableKey.IMAGE.value,
                ),
            ]

        except Exception as e:
            return self.create_text_message(f"Failed to generate image: {str(e)}")

    def _validate_color_string(self, color_string) -> bool:
        color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"

        if re.match(color_pattern, color_string):
            return True
        return False

    def get_runtime_parameters(self) -> list[ToolParameter]:
        parameters = [
            ToolParameter(
                name="prompt",
                label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description of the image you want to generate or modify",
                    zh_Hans="您想要生成或修改的图像的文本描述",
                ),
                llm_description="Describe the image you want to generate or how you want to modify the input image",
            ),
            ToolParameter(
                name="image_input_s3uri",
                label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
            ),
            ToolParameter(
                name="image_output_s3uri",
                label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
                type=ToolParameter.ToolParameterType.STRING,
                required=True,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
                ),
            ),
            ToolParameter(
                name="width",
                label=I18nObject(en_US="Width", zh_Hans="宽度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=1024,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
            ),
            ToolParameter(
                name="height",
                label=I18nObject(en_US="Height", zh_Hans="高度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=1024,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
            ),
            ToolParameter(
                name="cfg_scale",
                label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=8.0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
                ),
            ),
            ToolParameter(
                name="negative_prompt",
                label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="",
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
                ),
            ),
            ToolParameter(
                name="seed",
                label=I18nObject(en_US="Seed", zh_Hans="种子值"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
            ),
            ToolParameter(
                name="aws_region",
                label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="us-east-1",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
            ),
            ToolParameter(
                name="task_type",
                label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="TEXT_IMAGE",
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
            ),
            ToolParameter(
                name="quality",
                label=I18nObject(en_US="Quality", zh_Hans="质量"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="standard",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
                ),
            ),
            ToolParameter(
                name="colors",
                label=I18nObject(en_US="Colors", zh_Hans="颜色"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
                    zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
                ),
            ),
            ToolParameter(
                name="similarity_strength",
                label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
                type=ToolParameter.ToolParameterType.NUMBER,
                required=False,
                default=0.5,
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
                    zh_Hans="生成的图像应该与输入图像的相似程度(0.0到1.0)",
                ),
            ),
            ToolParameter(
                name="mask_prompt",
                label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                form=ToolParameter.ToolParameterForm.LLM,
                human_description=I18nObject(
                    en_US="Text description to generate mask for inpainting/outpainting",
                    zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
                ),
            ),
            ToolParameter(
                name="outpainting_mode",
                label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
                type=ToolParameter.ToolParameterType.STRING,
                required=False,
                default="DEFAULT",
                form=ToolParameter.ToolParameterForm.FORM,
                human_description=I18nObject(
                    en_US="Mode for outpainting (DEFAULT or other supported modes)",
                    zh_Hans="外补绘制的模式(DEFAULT或其他支持的模式)",
                ),
            ),
        ]

        return parameters
@ -0,0 +1,175 @@
|
|||
identity:
|
||||
name: nova_canvas
|
||||
author: AWS
|
||||
label:
|
||||
en_US: AWS Bedrock Nova Canvas
|
||||
zh_Hans: AWS Bedrock Nova Canvas
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
|
||||
zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
|
||||
llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
|
||||
parameters:
|
||||
- name: task_type
|
||||
type: string
|
||||
required: false
|
||||
default: TEXT_IMAGE
|
||||
label:
|
||||
en_US: Task Type
|
||||
zh_Hans: 任务类型
|
||||
human_description:
|
||||
en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
|
||||
zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
|
||||
form: llm
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Text description of the image you want to generate or modify
|
||||
zh_Hans: 您想要生成或修改的图像的文本描述
|
||||
llm_description: Describe the image you want to generate or how you want to modify the input image
|
||||
form: llm
|
||||
- name: image_input_s3uri
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Input image s3 uri
|
||||
zh_Hans: 输入图片的s3 uri
|
||||
human_description:
|
||||
en_US: The input image to modify (required for all modes except TEXT_IMAGE)
|
||||
zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
|
||||
llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
|
||||
form: llm
|
||||
- name: image_output_s3uri
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Output S3 URI
|
||||
zh_Hans: 输出S3 URI
|
||||
human_description:
|
||||
en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png
|
||||
zh_Hans: 生成的图像将保存到的S3 URI。如果提供,图像将以canvas-output-{timestamp}.png的格式上传
|
||||
llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
|
||||
form: form
|
||||
- name: negative_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Negative Prompt
|
||||
zh_Hans: 负面提示词
|
||||
human_description:
|
||||
en_US: Things you don't want in the generated image
|
||||
zh_Hans: 您不想在生成的图像中出现的内容
|
||||
form: llm
|
||||
- name: width
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Width
|
||||
zh_Hans: 宽度
|
||||
human_description:
|
||||
en_US: Width of the generated image
|
||||
zh_Hans: 生成图像的宽度
|
||||
form: form
|
||||
default: 1024
|
||||
- name: height
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Height
|
||||
zh_Hans: 高度
|
||||
human_description:
|
||||
en_US: Height of the generated image
|
||||
zh_Hans: 生成图像的高度
|
||||
form: form
|
||||
default: 1024
|
||||
- name: cfg_scale
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: CFG Scale
|
||||
zh_Hans: CFG比例
|
||||
human_description:
|
||||
en_US: How strongly the image should conform to the prompt
|
||||
zh_Hans: 图像应该多大程度上符合提示词
|
||||
form: form
|
||||
default: 8.0
|
||||
- name: seed
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Seed
|
||||
zh_Hans: 种子值
|
||||
human_description:
|
||||
en_US: Random seed for image generation
|
||||
zh_Hans: 图像生成的随机种子
|
||||
form: form
|
||||
default: 0
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
default: us-east-1
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: AWS region for Bedrock service
|
||||
zh_Hans: Bedrock 服务的 AWS 区域
|
||||
form: form
|
||||
- name: quality
|
||||
type: string
|
||||
required: false
|
||||
default: standard
|
||||
label:
|
||||
en_US: Quality
|
||||
zh_Hans: 质量
|
||||
human_description:
|
||||
en_US: Quality of the generated image (standard or premium)
|
||||
zh_Hans: 生成图像的质量(标准或高级)
|
||||
form: form
|
||||
- name: colors
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Colors
|
||||
zh_Hans: 颜色
|
||||
human_description:
|
||||
en_US: List of colors for color-guided generation
|
||||
zh_Hans: 颜色引导生成的颜色列表
|
||||
form: form
|
||||
- name: similarity_strength
|
||||
type: number
|
||||
required: false
|
||||
default: 0.5
|
||||
label:
|
||||
en_US: Similarity Strength
|
||||
zh_Hans: 相似度强度
|
||||
human_description:
|
||||
en_US: How similar the generated image should be to the input image (0.0 to 1.0)
|
||||
zh_Hans: 生成的图像应该与输入图像的相似程度(0.0到1.0)
|
||||
form: form
|
||||
- name: mask_prompt
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Mask Prompt
|
||||
zh_Hans: 蒙版提示词
|
||||
human_description:
|
||||
en_US: Text description to generate mask for inpainting/outpainting
|
||||
zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
|
||||
form: llm
|
||||
- name: outpainting_mode
|
||||
type: string
|
||||
required: false
|
||||
default: DEFAULT
|
||||
label:
|
||||
en_US: Outpainting Mode
|
||||
zh_Hans: 外补绘制模式
|
||||
human_description:
|
||||
en_US: Mode for outpainting (DEFAULT or other supported modes)
|
||||
zh_Hans: 外补绘制的模式(DEFAULT或其他支持的模式)
|
||||
form: form
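For orientation, the parameters above roughly correspond to a Bedrock Nova Canvas request body. The sketch below is illustrative only: the model ID, payload field names, and response shape are assumptions based on the public Bedrock image-generation documentation, not taken from this tool's implementation.

import base64
import json

import boto3

# Hypothetical TEXT_IMAGE request assembled from the parameters documented above.
bedrock = boto3.client("bedrock-runtime", region_name="us-east-1")
body = {
    "taskType": "TEXT_IMAGE",
    "textToImageParams": {"text": "a lighthouse at dusk", "negativeText": "blurry"},
    "imageGenerationConfig": {"width": 1024, "height": 1024, "cfgScale": 8.0, "seed": 0, "quality": "standard"},
}
response = bedrock.invoke_model(modelId="amazon.nova-canvas-v1:0", body=json.dumps(body))
result = json.loads(response["body"].read())
image_bytes = base64.b64decode(result["images"][0])  # bytes ready to upload to image_output_s3uri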
|
||||
|
|
@ -0,0 +1,371 @@
|
|||
import base64
|
||||
import logging
|
||||
import time
|
||||
from io import BytesIO
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
from PIL import Image
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NOVA_REEL_DEFAULT_REGION = "us-east-1"
|
||||
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
|
||||
NOVA_REEL_DEFAULT_FPS = 24
|
||||
NOVA_REEL_DEFAULT_DURATION = 6
|
||||
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
|
||||
NOVA_REEL_STATUS_CHECK_INTERVAL = 5
|
||||
|
||||
# Image requirements
|
||||
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
|
||||
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
|
||||
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"
|
||||
|
||||
|
||||
class NovaReelTool(BuiltinTool):
|
||||
def _invoke(
|
||||
self, user_id: str, tool_parameters: dict[str, Any]
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Invoke AWS Bedrock Nova Reel model for video generation.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user making the request
|
||||
tool_parameters: Dictionary containing the tool parameters
|
||||
|
||||
Returns:
|
||||
ToolInvokeMessage containing either the video content or status information
|
||||
"""
|
||||
try:
|
||||
# Validate and extract parameters
|
||||
params = self._validate_and_extract_parameters(tool_parameters)
|
||||
if isinstance(params, ToolInvokeMessage):
|
||||
return params
|
||||
|
||||
# Initialize AWS clients
|
||||
bedrock, s3_client = self._initialize_aws_clients(params["aws_region"])
|
||||
|
||||
# Prepare model input
|
||||
model_input = self._prepare_model_input(params, s3_client)
|
||||
if isinstance(model_input, ToolInvokeMessage):
|
||||
return model_input
|
||||
|
||||
# Start video generation
|
||||
invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"])
|
||||
invocation_arn = invocation["invocationArn"]
|
||||
|
||||
# Handle async/sync mode
|
||||
return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"])
|
||||
|
||||
except ClientError as e:
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_message = e.response.get("Error", {}).get("Message", str(e))
|
||||
logger.exception(f"AWS API error: {error_code} - {error_message}")
|
||||
return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(f"Failed to generate video: {str(e)}")
|
||||
|
||||
def _validate_and_extract_parameters(
|
||||
self, tool_parameters: dict[str, Any]
|
||||
) -> Union[dict[str, Any], ToolInvokeMessage]:
|
||||
"""Validate and extract parameters from the input dictionary."""
|
||||
prompt = tool_parameters.get("prompt", "")
|
||||
video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip()
|
||||
|
||||
# Validate required parameters
|
||||
if not prompt:
|
||||
return self.create_text_message("Please provide a text prompt for video generation.")
|
||||
if not video_output_s3uri:
|
||||
return self.create_text_message("Please provide an S3 URI for video output.")
|
||||
|
||||
# Validate S3 URI format
|
||||
if not video_output_s3uri.startswith("s3://"):
|
||||
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
|
||||
|
||||
# Ensure S3 URI ends with '/'
|
||||
video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/"
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"video_output_s3uri": video_output_s3uri,
|
||||
"image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(),
|
||||
"aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION),
|
||||
"dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION),
|
||||
"seed": int(tool_parameters.get("seed", 0)),
|
||||
"fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)),
|
||||
"duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)),
|
||||
"async_mode": bool(tool_parameters.get("async", True)),
|
||||
}
|
||||
|
||||
def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
|
||||
"""Initialize AWS Bedrock and S3 clients."""
|
||||
bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
|
||||
s3_client = boto3.client("s3", region_name=region)
|
||||
return bedrock, s3_client
|
||||
|
||||
def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
|
||||
"""Prepare the input for the Nova Reel model."""
|
||||
model_input = {
|
||||
"taskType": "TEXT_VIDEO",
|
||||
"textToVideoParams": {"text": params["prompt"]},
|
||||
"videoGenerationConfig": {
|
||||
"durationSeconds": params["duration"],
|
||||
"fps": params["fps"],
|
||||
"dimension": params["dimension"],
|
||||
"seed": params["seed"],
|
||||
},
|
||||
}
|
||||
|
||||
# Add image if provided
|
||||
if params["image_input_s3uri"]:
|
||||
try:
|
||||
image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"])
|
||||
if not image_data:
|
||||
return self.create_text_message("Failed to retrieve image from S3")
|
||||
|
||||
# Process and validate image
|
||||
processed_image = self._process_and_validate_image(image_data)
|
||||
if isinstance(processed_image, ToolInvokeMessage):
|
||||
return processed_image
|
||||
|
||||
# Convert processed image to base64
|
||||
img_buffer = BytesIO()
|
||||
processed_image.save(img_buffer, format="PNG")
|
||||
img_buffer.seek(0)
|
||||
input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
|
||||
|
||||
model_input["textToVideoParams"]["images"] = [
|
||||
{"format": "png", "source": {"bytes": input_image_base64}}
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing input image: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(f"Failed to process input image: {str(e)}")
|
||||
|
||||
return model_input
|
||||
|
||||
def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
|
||||
"""
|
||||
Process and validate the input image according to Nova Reel requirements.
|
||||
|
||||
Requirements:
|
||||
- Must be 1280x720 pixels
|
||||
- Must be RGB format (8 bits per channel)
|
||||
- If PNG, alpha channel must not have transparent/translucent pixels
|
||||
"""
|
||||
try:
|
||||
# Open image
|
||||
img = Image.open(BytesIO(image_data))
|
||||
|
||||
# Convert RGBA to RGB if needed, ensuring no transparency
|
||||
if img.mode == "RGBA":
|
||||
# Check for transparency
|
||||
if img.getchannel("A").getextrema()[0] < 255:
|
||||
return self.create_text_message(
|
||||
"PNG image contains transparent or translucent pixels, which is not supported. "
|
||||
"Please provide an image without transparency."
|
||||
)
|
||||
# Convert to RGB
|
||||
img = img.convert("RGB")
|
||||
elif img.mode != "RGB":
|
||||
# Convert any other mode to RGB
|
||||
img = img.convert("RGB")
|
||||
|
||||
# Validate/adjust dimensions
|
||||
if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT):
|
||||
logger.warning(
|
||||
f"Image dimensions {img.size} do not match required dimensions "
|
||||
f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
|
||||
)
|
||||
img = img.resize(
|
||||
(NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
# Validate bit depth
|
||||
if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
|
||||
return self.create_text_message(
|
||||
f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(
|
||||
"Failed to process image. Please ensure the image is a valid JPEG or PNG file."
|
||||
)
|
||||
|
||||
def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
|
||||
"""Download and return image data from S3."""
|
||||
parsed_uri = urlparse(s3_uri)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
return response["Body"].read()
|
||||
|
||||
def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
|
||||
"""Start the async video generation process."""
|
||||
return bedrock.start_async_invoke(
|
||||
modelId=NOVA_REEL_MODEL_ID,
|
||||
modelInput=model_input,
|
||||
outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}},
|
||||
)
|
||||
|
||||
def _handle_generation_mode(
|
||||
self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
|
||||
) -> ToolInvokeMessage:
|
||||
"""Handle async or sync video generation mode."""
|
||||
invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
|
||||
video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
|
||||
video_uri = f"{video_path}/output.mp4"
|
||||
|
||||
if async_mode:
|
||||
return self.create_text_message(
|
||||
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
|
||||
f"Video will be available at: {video_uri}"
|
||||
)
|
||||
|
||||
return self._wait_for_completion(bedrock, s3_client, invocation_arn)
|
||||
|
||||
def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage:
|
||||
"""Wait for video generation completion and handle the result."""
|
||||
while True:
|
||||
status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
|
||||
status = status_response["status"]
|
||||
video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
|
||||
|
||||
if status == "Completed":
|
||||
return self._handle_completed_video(s3_client, video_path)
|
||||
elif status == "Failed":
|
||||
failure_message = status_response.get("failureMessage", "Unknown error")
|
||||
return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
|
||||
elif status == "InProgress":
|
||||
time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
|
||||
else:
|
||||
return self.create_text_message(f"Unexpected status: {status}")
|
||||
|
||||
def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage:
|
||||
"""Handle completed video generation and return the result."""
|
||||
parsed_uri = urlparse(video_path)
|
||||
bucket = parsed_uri.netloc
|
||||
key = parsed_uri.path.lstrip("/") + "/output.mp4"
|
||||
|
||||
try:
|
||||
response = s3_client.get_object(Bucket=bucket, Key=key)
|
||||
video_content = response["Body"].read()
|
||||
return [
|
||||
self.create_text_message(f"Video is available at: {video_path}/output.mp4"),
|
||||
self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading video: {str(e)}", exc_info=True)
|
||||
return self.create_text_message(
|
||||
f"Video generation completed but failed to download video: {str(e)}\n"
|
||||
f"Video is available at: s3://{bucket}/{key}"
|
||||
)
|
||||
|
||||
def get_runtime_parameters(self) -> list[ToolParameter]:
|
||||
"""Define the tool's runtime parameters."""
|
||||
parameters = [
|
||||
ToolParameter(
|
||||
name="prompt",
|
||||
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
|
||||
),
|
||||
llm_description="Describe the video you want to generate",
|
||||
),
|
||||
ToolParameter(
|
||||
name="video_output_s3uri",
|
||||
label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=True,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="dimension",
|
||||
label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_DIMENSION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="duration",
|
||||
label=I18nObject(en_US="Duration", zh_Hans="时长"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_DURATION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="seed",
|
||||
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=0,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="fps",
|
||||
label=I18nObject(en_US="FPS", zh_Hans="帧率"),
|
||||
type=ToolParameter.ToolParameterType.NUMBER,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_FPS,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(
|
||||
en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="async",
|
||||
label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
|
||||
type=ToolParameter.ToolParameterType.BOOLEAN,
|
||||
required=False,
|
||||
default=True,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
|
||||
zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
|
||||
),
|
||||
),
|
||||
ToolParameter(
|
||||
name="aws_region",
|
||||
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
default=NOVA_REEL_DEFAULT_REGION,
|
||||
form=ToolParameter.ToolParameterForm.FORM,
|
||||
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
|
||||
),
|
||||
ToolParameter(
|
||||
name="image_input_s3uri",
|
||||
label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=False,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
human_description=I18nObject(
|
||||
en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
|
||||
zh_Hans="用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
return parameters
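A rough usage sketch for the tool above. In practice Dify's tool runtime constructs and invokes the tool, so the direct call shown in the comment is only illustrative, and the bucket name is hypothetical.

nova_reel_parameters = {
    "prompt": "a drone shot over a rocky coastline at sunrise",
    "video_output_s3uri": "s3://my-bucket/nova-reel-output/",  # hypothetical bucket
    "image_input_s3uri": "",  # leave empty for pure text-to-video
    "dimension": "1280x720",
    "duration": 6,
    "fps": 24,
    "seed": 0,
    "async": True,  # return immediately with the invocation ARN
    "aws_region": "us-east-1",
}
# NovaReelTool()._invoke(user_id="user-1", tool_parameters=nova_reel_parameters) would return a
# text message with the invocation ARN and the eventual s3://.../output.mp4 location;
# with "async": False it polls until the video is ready and returns the MP4 blob.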
|
||||
|
|
@ -0,0 +1,124 @@
|
|||
identity:
|
||||
name: nova_reel
|
||||
author: AWS
|
||||
label:
|
||||
en_US: AWS Bedrock Nova Reel
|
||||
zh_Hans: AWS Bedrock Nova Reel
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
|
||||
zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
|
||||
llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.
|
||||
|
||||
parameters:
|
||||
- name: prompt
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Prompt
|
||||
zh_Hans: 提示词
|
||||
human_description:
|
||||
en_US: Text description of the video you want to generate
|
||||
zh_Hans: 您想要生成的视频的文本描述
|
||||
llm_description: Describe the video you want to generate
|
||||
form: llm
|
||||
|
||||
- name: video_output_s3uri
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Output S3 URI
|
||||
zh_Hans: 输出S3 URI
|
||||
human_description:
|
||||
en_US: S3 URI where the generated video will be stored
|
||||
zh_Hans: 生成的视频将存储的S3 URI
|
||||
form: form
|
||||
|
||||
- name: dimension
|
||||
type: string
|
||||
required: false
|
||||
default: 1280x720
|
||||
label:
|
||||
en_US: Dimension
|
||||
zh_Hans: 尺寸
|
||||
human_description:
|
||||
en_US: Video dimensions (width x height)
|
||||
zh_Hans: 视频尺寸(宽 x 高)
|
||||
form: form
|
||||
|
||||
- name: duration
|
||||
type: number
|
||||
required: false
|
||||
default: 6
|
||||
label:
|
||||
en_US: Duration
|
||||
zh_Hans: 时长
|
||||
human_description:
|
||||
en_US: Video duration in seconds
|
||||
zh_Hans: 视频时长(秒)
|
||||
form: form
|
||||
|
||||
- name: seed
|
||||
type: number
|
||||
required: false
|
||||
default: 0
|
||||
label:
|
||||
en_US: Seed
|
||||
zh_Hans: 种子值
|
||||
human_description:
|
||||
en_US: Random seed for video generation
|
||||
zh_Hans: 视频生成的随机种子
|
||||
form: form
|
||||
|
||||
- name: fps
|
||||
type: number
|
||||
required: false
|
||||
default: 24
|
||||
label:
|
||||
en_US: FPS
|
||||
zh_Hans: 帧率
|
||||
human_description:
|
||||
en_US: Frames per second for the generated video
|
||||
zh_Hans: 生成视频的每秒帧数
|
||||
form: form
|
||||
|
||||
- name: async
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
label:
|
||||
en_US: Async Mode
|
||||
zh_Hans: 异步模式
|
||||
human_description:
|
||||
en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
|
||||
zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
|
||||
form: llm
|
||||
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: false
|
||||
default: us-east-1
|
||||
label:
|
||||
en_US: AWS Region
|
||||
zh_Hans: AWS 区域
|
||||
human_description:
|
||||
en_US: AWS region for Bedrock service
|
||||
zh_Hans: Bedrock 服务的 AWS 区域
|
||||
form: form
|
||||
|
||||
- name: image_input_s3uri
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: Input Image S3 URI
|
||||
zh_Hans: 输入图像S3 URI
|
||||
human_description:
|
||||
en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as the first frame
|
||||
zh_Hans: 用作第一帧的输入图像(1280x720 JPEG/PNG)的S3 URI
|
||||
form: llm
|
||||
|
||||
development:
|
||||
dependencies:
|
||||
- boto3
|
||||
- pillow
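Since Nova Reel expects a 1280x720 RGB first frame, a quick local pre-check with Pillow (mirroring the validation performed by the tool above; the file path is hypothetical) might look like:

from PIL import Image

with Image.open("first_frame.png") as img:
    ready = img.size == (1280, 720) and img.mode == "RGB"
print("ready for Nova Reel" if ready else "the tool will resize/convert it, or reject transparency")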
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
from typing import Any, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import boto3
|
||||
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
|
||||
|
||||
class S3Operator(BuiltinTool):
|
||||
s3_client: Any = None
|
||||
|
||||
def _invoke(
|
||||
self,
|
||||
user_id: str,
|
||||
tool_parameters: dict[str, Any],
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
"""
|
||||
Read an object from S3 or write text content to one, optionally returning a presigned URL.
|
||||
"""
|
||||
try:
|
||||
# Initialize S3 client if not already done
|
||||
if not self.s3_client:
|
||||
aws_region = tool_parameters.get("aws_region")
|
||||
if aws_region:
|
||||
self.s3_client = boto3.client("s3", region_name=aws_region)
|
||||
else:
|
||||
self.s3_client = boto3.client("s3")
|
||||
|
||||
# Parse S3 URI
|
||||
s3_uri = tool_parameters.get("s3_uri")
|
||||
if not s3_uri:
|
||||
return self.create_text_message("s3_uri parameter is required")
|
||||
|
||||
parsed_uri = urlparse(s3_uri)
|
||||
if parsed_uri.scheme != "s3":
|
||||
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
|
||||
|
||||
bucket = parsed_uri.netloc
|
||||
# Remove leading slash from key
|
||||
key = parsed_uri.path.lstrip("/")
|
||||
|
||||
operation_type = tool_parameters.get("operation_type", "read")
|
||||
generate_presign_url = tool_parameters.get("generate_presign_url", False)
|
||||
presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour
|
||||
|
||||
if operation_type == "write":
|
||||
text_content = tool_parameters.get("text_content")
|
||||
if not text_content:
|
||||
return self.create_text_message("text_content parameter is required for write operation")
|
||||
|
||||
# Write content to S3
|
||||
self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
|
||||
result = f"s3://{bucket}/{key}"
|
||||
|
||||
# Generate presigned URL for the written object if requested
|
||||
if generate_presign_url:
|
||||
result = self.s3_client.generate_presigned_url(
|
||||
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
|
||||
)
|
||||
|
||||
else: # read operation
|
||||
# Get object from S3
|
||||
response = self.s3_client.get_object(Bucket=bucket, Key=key)
|
||||
result = response["Body"].read().decode("utf-8")
|
||||
|
||||
# Generate presigned URL if requested
|
||||
if generate_presign_url:
|
||||
result = self.s3_client.generate_presigned_url(
|
||||
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
|
||||
)
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
except self.s3_client.exceptions.NoSuchBucket:
|
||||
return self.create_text_message(f"Bucket '{bucket}' does not exist")
|
||||
except self.s3_client.exceptions.NoSuchKey:
|
||||
return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
|
||||
except Exception as e:
|
||||
return self.create_text_message(f"Exception: {str(e)}")
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
identity:
|
||||
name: s3_operator
|
||||
author: AWS
|
||||
label:
|
||||
en_US: AWS S3 Operator
|
||||
zh_Hans: AWS S3 读写器
|
||||
pt_BR: AWS S3 Operator
|
||||
icon: icon.svg
|
||||
description:
|
||||
human:
|
||||
en_US: AWS S3 Writer and Reader
|
||||
zh_Hans: 读写S3 bucket中的文件
|
||||
pt_BR: AWS S3 Writer and Reader
|
||||
llm: AWS S3 Writer and Reader
|
||||
parameters:
|
||||
- name: text_content
|
||||
type: string
|
||||
required: false
|
||||
label:
|
||||
en_US: The text to write
|
||||
zh_Hans: 待写入的文本
|
||||
pt_BR: The text to write
|
||||
human_description:
|
||||
en_US: The text to write
|
||||
zh_Hans: 待写入的文本
|
||||
pt_BR: The text to write
|
||||
llm_description: The text to write
|
||||
form: llm
|
||||
- name: s3_uri
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: S3 URI
|
||||
zh_Hans: s3 uri
|
||||
pt_BR: s3 uri
|
||||
human_description:
|
||||
en_US: The S3 URI of the target object (s3://bucket/key)
|
||||
zh_Hans: s3 uri
|
||||
pt_BR: s3 uri
|
||||
llm_description: The S3 URI of the target object, in the form s3://bucket/key
|
||||
form: llm
|
||||
- name: aws_region
|
||||
type: string
|
||||
required: true
|
||||
label:
|
||||
en_US: Region of the bucket
|
||||
zh_Hans: bucket 所在的region
|
||||
pt_BR: region of bucket
|
||||
human_description:
|
||||
en_US: The AWS region where the bucket is located
|
||||
zh_Hans: bucket 所在的region
|
||||
pt_BR: region of bucket
|
||||
llm_description: The AWS region where the bucket is located
|
||||
form: form
|
||||
- name: operation_type
|
||||
type: select
|
||||
required: true
|
||||
label:
|
||||
en_US: Operation type
|
||||
zh_Hans: 操作类型
|
||||
pt_BR: operation type
|
||||
human_description:
|
||||
en_US: Whether to read the object or write text content to it
|
||||
zh_Hans: 操作类型
|
||||
pt_BR: operation type
|
||||
default: read
|
||||
options:
|
||||
- value: read
|
||||
label:
|
||||
en_US: read
|
||||
zh_Hans: 读
|
||||
- value: write
|
||||
label:
|
||||
en_US: write
|
||||
zh_Hans: 写
|
||||
form: form
|
||||
- name: generate_presign_url
|
||||
type: boolean
|
||||
required: false
|
||||
label:
|
||||
en_US: Generate presigned URL
|
||||
zh_Hans: 生成预签名URL
|
||||
human_description:
|
||||
en_US: Whether to generate a presigned URL for the S3 object
|
||||
zh_Hans: 是否生成S3对象的预签名URL
|
||||
default: false
|
||||
form: form
|
||||
- name: presign_expiry
|
||||
type: number
|
||||
required: false
|
||||
label:
|
||||
en_US: Presigned URL expiration time
|
||||
zh_Hans: 预签名URL有效期
|
||||
human_description:
|
||||
en_US: Expiration time in seconds for the presigned URL
|
||||
zh_Hans: 预签名URL的有效期(秒)
|
||||
default: 3600
|
||||
form: form
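For reference, a tool_parameters mapping that matches the schema above (values are illustrative, and the direct invocation in the comment is only a sketch; Dify's tool runtime normally performs it):

s3_parameters = {
    "s3_uri": "s3://my-bucket/notes/hello.txt",  # hypothetical bucket/key
    "aws_region": "us-east-1",
    "operation_type": "write",
    "text_content": "Hello from Dify",
    "generate_presign_url": True,
    "presign_expiry": 600,
}
# S3Operator()._invoke(user_id="user-1", tool_parameters=s3_parameters) would write the object
# and return a presigned GET URL valid for 10 minutes; with "operation_type": "read" it returns
# the object's text content instead.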
|
||||
|
|
@ -210,7 +210,7 @@ class ApiTool(Tool):
|
|||
)
|
||||
return response
|
||||
else:
|
||||
raise ValueError(f"Invalid http method {self.method}")
|
||||
raise ValueError(f"Invalid http method {method}")
|
||||
|
||||
def _convert_body_property_any_of(
|
||||
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
|
||||
|
|
|
|||
|
|
@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type
|
|||
from typing import Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from httpx import get
|
||||
import httpx
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import MessageFile
|
||||
|
|
@ -94,12 +95,11 @@ class ToolFileManager:
|
|||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
response = get(file_url)
|
||||
response = ssrf_proxy.get(file_url)
|
||||
response.raise_for_status()
|
||||
blob = response.content
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to download file from {file_url}")
|
||||
raise
|
||||
except httpx.TimeoutException as e:
|
||||
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||
|
||||
mimetype = guess_type(file_url)[0] or "octet/stream"
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
class BaseNodeError(Exception):
|
||||
class BaseNodeError(ValueError):
|
||||
"""Base class for node errors."""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
|
||||
|
|
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
)
|
||||
|
||||
# Transform result
|
||||
result = self._transform_result(result, self.node_data.outputs)
|
||||
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
|
||||
except (CodeExecutionError, CodeNodeError) as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
|
||||
|
|
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
|
||||
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
|
||||
|
||||
def _check_string(self, value: str, variable: str) -> str:
|
||||
def _check_string(self, value: str | None, variable: str) -> str | None:
|
||||
"""
|
||||
Check string
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, str):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a string")
|
||||
|
||||
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
|
||||
raise OutputValidationError(
|
||||
|
|
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
|
||||
return value.replace("\x00", "")
|
||||
|
||||
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
|
||||
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
|
||||
"""
|
||||
Check number
|
||||
:param value: value
|
||||
:param variable: variable
|
||||
:return:
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if not isinstance(value, int | float):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
raise OutputValidationError(f"Output variable `{variable}` must be a number")
|
||||
|
||||
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
|
||||
raise OutputValidationError(
|
||||
|
|
@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
|
|||
return value
|
||||
|
||||
def _transform_result(
|
||||
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
|
||||
) -> dict:
|
||||
"""
|
||||
Transform result
|
||||
:param result: result
|
||||
:param output_schema: output schema
|
||||
:return:
|
||||
"""
|
||||
self,
|
||||
result: Mapping[str, Any],
|
||||
output_schema: Optional[dict[str, CodeNodeData.Output]],
|
||||
prefix: str = "",
|
||||
depth: int = 1,
|
||||
):
|
||||
if depth > dify_config.CODE_MAX_DEPTH:
|
||||
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import csv
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
|
|
@ -22,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
|
|||
from .entities import DocumentExtractorNodeData
|
||||
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
|
||||
"""
|
||||
|
|
@ -177,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
|
|||
|
||||
|
||||
def _extract_text_from_doc(file_content: bytes) -> str:
|
||||
"""
|
||||
Extract text from a DOC/DOCX file.
|
||||
For now, only paragraphs and tables are supported; extend as needed.
|
||||
"""
|
||||
try:
|
||||
doc_file = io.BytesIO(file_content)
|
||||
doc = docx.Document(doc_file)
|
||||
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
|
||||
text = []
|
||||
# Process paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
if paragraph.text.strip():
|
||||
text.append(paragraph.text)
|
||||
|
||||
# Process tables
|
||||
for table in doc.tables:
|
||||
# Table header
|
||||
try:
|
||||
# Tables can raise errors, so skip any table that fails to parse.
|
||||
if len(table.rows) > 0 and table.rows[0].cells is not None:
|
||||
# Check if any cell in the table has text
|
||||
has_content = False
|
||||
for row in table.rows:
|
||||
if any(cell.text.strip() for cell in row.cells):
|
||||
has_content = True
|
||||
break
|
||||
|
||||
if has_content:
|
||||
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
|
||||
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
|
||||
for row in table.rows[1:]:
|
||||
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
|
||||
text.append(markdown_table)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
|
||||
continue
|
||||
|
||||
return "\n".join(text)
|
||||
except Exception as e:
|
||||
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e
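As a worked example of the table handling above, applied to plain lists rather than python-docx rows:

rows = [["Name", "Role"], ["Ada", "Engineer"]]
markdown_table = "| " + " | ".join(rows[0]) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n"
for row in rows[1:]:
    markdown_table += "| " + " | ".join(row) + " |\n"
# markdown_table is now:
# | Name | Role |
# | --- | --- |
# | Ada | Engineer |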
|
||||
|
||||
|
|
|
|||
|
|
@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode):
|
|||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
|
||||
error=str(e),
|
||||
metadata={},
|
||||
)
|
||||
|
||||
error = None
|
||||
|
||||
|
|
|
|||
|
|
@ -154,8 +154,7 @@ class QuestionClassifierNode(LLMNode):
|
|||
},
|
||||
llm_usage=usage,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=variables,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
class VariableOperatorNodeError(Exception):
|
||||
class VariableOperatorNodeError(ValueError):
|
||||
"""Base error type, don't use directly."""
|
||||
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -116,8 +116,11 @@ def _build_from_local_file(
|
|||
tenant_id: str,
|
||||
transfer_method: FileTransferMethod,
|
||||
) -> File:
|
||||
upload_file_id = mapping.get("upload_file_id")
|
||||
if not upload_file_id:
|
||||
raise ValueError("Invalid upload file id")
|
||||
stmt = select(UploadFile).where(
|
||||
UploadFile.id == mapping.get("upload_file_id"),
|
||||
UploadFile.id == upload_file_id,
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict:
|
|||
extracted_content = json_string[start_index:end_index].strip()
|
||||
parsed = json.loads(extracted_content)
|
||||
else:
|
||||
raise Exception("Could not find JSON block in the output.")
|
||||
raise ValueError("could not find json block in the output.")
|
||||
|
||||
return parsed
|
||||
|
||||
|
|
@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
|
|||
try:
|
||||
json_obj = parse_json_markdown(text)
|
||||
except json.JSONDecodeError as e:
|
||||
raise OutputParserError(f"Got invalid JSON object. Error: {e}")
|
||||
raise OutputParserError(f"got invalid json object. error: {e}")
|
||||
for key in expected_keys:
|
||||
if key not in json_obj:
|
||||
raise OutputParserError(
|
||||
f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}"
|
||||
f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
|
||||
)
|
||||
return json_obj
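A minimal usage sketch for the helpers above (the fenced-markdown input format is an assumption based on the block-extraction logic in parse_json_markdown):

llm_output = 'Classification result:\n```json\n{"category": "billing"}\n```'
parsed = parse_and_check_json_markdown(llm_output, expected_keys=["category"])
# parsed == {"category": "billing"}; a missing key or malformed JSON raises OutputParserError.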
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import enum
|
|||
import json
|
||||
|
||||
from flask_login import UserMixin
|
||||
from sqlalchemy import func
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
|
@ -30,11 +31,11 @@ class Account(UserMixin, db.Model):
|
|||
timezone = db.Column(db.String(255))
|
||||
last_login_at = db.Column(db.DateTime)
|
||||
last_login_ip = db.Column(db.String(255))
|
||||
last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
|
||||
initialized_at = db.Column(db.DateTime)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def is_password_set(self):
|
||||
|
|
@ -187,8 +188,8 @@ class Tenant(db.Model):
|
|||
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
|
||||
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
custom_config = db.Column(db.Text)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def get_accounts(self) -> list[Account]:
|
||||
return (
|
||||
|
|
@ -228,8 +229,8 @@ class TenantAccountJoin(db.Model):
|
|||
current = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
role = db.Column(db.String(16), nullable=False, server_default="normal")
|
||||
invited_by = db.Column(StringUUID, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AccountIntegrate(db.Model):
|
||||
|
|
@ -245,8 +246,8 @@ class AccountIntegrate(db.Model):
|
|||
provider = db.Column(db.String(16), nullable=False)
|
||||
open_id = db.Column(db.String(255), nullable=False)
|
||||
encrypted_token = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class InvitationCode(db.Model):
|
||||
|
|
@ -265,4 +266,4 @@ class InvitationCode(db.Model):
|
|||
used_by_tenant_id = db.Column(StringUUID)
|
||||
used_by_account_id = db.Column(StringUUID)
|
||||
deprecated_at = db.Column(db.DateTime)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
import enum
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
|
@ -23,4 +25,4 @@ class APIBasedExtension(db.Model):
|
|||
name = db.Column(db.String(255), nullable=False)
|
||||
api_endpoint = db.Column(db.String(255), nullable=False)
|
||||
api_key = db.Column(db.Text, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -50,9 +50,9 @@ class Dataset(db.Model):
|
|||
indexing_technique = db.Column(db.String(255), nullable=True)
|
||||
index_struct = db.Column(db.Text, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
embedding_model = db.Column(db.String(255), nullable=True)
|
||||
embedding_model_provider = db.Column(db.String(255), nullable=True)
|
||||
collection_binding_id = db.Column(StringUUID, nullable=True)
|
||||
|
|
@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model):
|
|||
mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
|
||||
rules = db.Column(db.Text, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
MODES = ["automatic", "custom"]
|
||||
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
|
||||
|
|
@ -264,7 +264,7 @@ class Document(db.Model):
|
|||
created_from = db.Column(db.String(255), nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_api_request_id = db.Column(StringUUID, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
# start processing
|
||||
processing_started_at = db.Column(db.DateTime, nullable=True)
|
||||
|
|
@ -303,7 +303,7 @@ class Document(db.Model):
|
|||
archived_reason = db.Column(db.String(255), nullable=True)
|
||||
archived_by = db.Column(StringUUID, nullable=True)
|
||||
archived_at = db.Column(db.DateTime, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
doc_type = db.Column(db.String(40), nullable=True)
|
||||
doc_metadata = db.Column(db.JSON, nullable=True)
|
||||
doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
|
||||
|
|
@ -527,9 +527,9 @@ class DocumentSegment(db.Model):
|
|||
disabled_by = db.Column(StringUUID, nullable=True)
|
||||
status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
indexing_at = db.Column(db.DateTime, nullable=True)
|
||||
completed_at = db.Column(db.DateTime, nullable=True)
|
||||
error = db.Column(db.Text, nullable=True)
|
||||
|
|
@ -697,7 +697,7 @@ class Embedding(db.Model):
|
|||
)
|
||||
hash = db.Column(db.String(64), nullable=False)
|
||||
embedding = db.Column(db.LargeBinary, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
|
||||
|
||||
def set_embedding(self, embedding_data: list[float]):
|
||||
|
|
@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model):
|
|||
model_name = db.Column(db.String(255), nullable=False)
|
||||
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
|
||||
collection_name = db.Column(db.String(64), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TidbAuthBinding(db.Model):
|
||||
|
|
@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model):
|
|||
status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))
|
||||
account = db.Column(db.String(255), nullable=False)
|
||||
password = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Whitelist(db.Model):
|
||||
|
|
@ -751,7 +751,7 @@ class Whitelist(db.Model):
|
|||
id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
|
||||
tenant_id = db.Column(StringUUID, nullable=True)
|
||||
category = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class DatasetPermission(db.Model):
|
||||
|
|
@ -768,7 +768,7 @@ class DatasetPermission(db.Model):
|
|||
account_id = db.Column(StringUUID, nullable=False)
|
||||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ExternalKnowledgeApis(db.Model):
|
||||
|
|
@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model):
|
|||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
settings = db.Column(db.Text, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
|
|
@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model):
|
|||
dataset_id = db.Column(StringUUID, nullable=False)
|
||||
external_knowledge_id = db.Column(db.Text, nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class DifySetup(db.Model):
|
|||
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
|
||||
|
||||
version = db.Column(db.String(255), nullable=False)
|
||||
setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class AppMode(StrEnum):
|
||||
|
|
@ -85,9 +85,9 @@ class App(db.Model):
|
|||
tracing = db.Column(db.Text, nullable=True)
|
||||
max_active_requests = db.Column(db.Integer, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
|
||||
@property
|
||||
|
|
@ -226,9 +226,9 @@ class AppModelConfig(db.Model):
|
|||
model_id = db.Column(db.String(255), nullable=True)
|
||||
configs = db.Column(db.JSON, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
opening_statement = db.Column(db.Text)
|
||||
suggested_questions = db.Column(db.Text)
|
||||
suggested_questions_after_answer = db.Column(db.Text)
|
||||
|
|
@ -482,8 +482,8 @@ class RecommendedApp(db.Model):
|
|||
is_listed = db.Column(db.Boolean, nullable=False, default=True)
|
||||
install_count = db.Column(db.Integer, nullable=False, default=0)
|
||||
language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -507,7 +507,7 @@ class InstalledApp(db.Model):
|
|||
position = db.Column(db.Integer, nullable=False, default=0)
|
||||
is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
last_used_at = db.Column(db.DateTime, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def app(self):
|
||||
|
|
@ -548,8 +548,8 @@ class Conversation(db.Model):
|
|||
read_at = db.Column(db.DateTime)
|
||||
read_account_id = db.Column(StringUUID)
|
||||
dialogue_count: Mapped[int] = mapped_column(default=0)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
|
||||
message_annotations = db.relationship(
|
||||
|
|
@ -700,8 +700,10 @@ class Conversation(db.Model):
|
|||
def status_count(self):
|
||||
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
|
||||
status_counts = {
|
||||
WorkflowRunStatus.RUNNING: 0,
|
||||
WorkflowRunStatus.SUCCEEDED: 0,
|
||||
WorkflowRunStatus.FAILED: 0,
|
||||
WorkflowRunStatus.STOPPED: 0,
|
||||
WorkflowRunStatus.PARTIAL_SUCCESSED: 0,
|
||||
}
|
||||
|
||||
|
|
@ -789,8 +791,8 @@ class Message(db.Model):
|
|||
from_source = db.Column(db.String(255), nullable=False)
|
||||
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
|
||||
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
workflow_run_id = db.Column(StringUUID)
|
||||
|
||||
|
|
@ -1115,8 +1117,8 @@ class MessageFeedback(db.Model):
|
|||
from_source = db.Column(db.String(255), nullable=False)
|
||||
from_end_user_id = db.Column(StringUUID)
|
||||
from_account_id = db.Column(StringUUID)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def from_account(self):
|
||||
|
|
@ -1162,9 +1164,7 @@ class MessageFile(db.Model):
|
|||
upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
|
||||
created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageAnnotation(db.Model):
|
||||
|
|
@ -1184,8 +1184,8 @@ class MessageAnnotation(db.Model):
|
|||
content = db.Column(db.Text, nullable=False)
|
||||
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
|
||||
account_id = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def account(self):
|
||||
|
|
@ -1214,7 +1214,7 @@ class AppAnnotationHitHistory(db.Model):
|
|||
source = db.Column(db.Text, nullable=False)
|
||||
question = db.Column(db.Text, nullable=False)
|
||||
account_id = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
score = db.Column(Float, nullable=False, server_default=db.text("0"))
|
||||
message_id = db.Column(StringUUID, nullable=False)
|
||||
annotation_question = db.Column(db.Text, nullable=False)
|
||||
|
|
@ -1248,9 +1248,9 @@ class AppAnnotationSetting(db.Model):
|
|||
score_threshold = db.Column(Float, nullable=False, server_default=db.text("0"))
|
||||
collection_binding_id = db.Column(StringUUID, nullable=False)
|
||||
created_user_id = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_user_id = db.Column(StringUUID, nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def created_account(self):
|
||||
|
|
@ -1296,9 +1296,9 @@ class OperationLog(db.Model):
|
|||
account_id = db.Column(StringUUID, nullable=False)
|
||||
action = db.Column(db.String(255), nullable=False)
|
||||
content = db.Column(db.JSON)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
created_ip = db.Column(db.String(255), nullable=False)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class EndUser(UserMixin, db.Model):
|
||||
|
|
@ -1317,8 +1317,8 @@ class EndUser(UserMixin, db.Model):
|
|||
name = db.Column(db.String(255))
|
||||
is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
session_id = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class Site(db.Model):
|
||||
|
|
@ -1349,9 +1349,9 @@ class Site(db.Model):
|
|||
prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
|
||||
created_by = db.Column(StringUUID, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_by = db.Column(StringUUID, nullable=True)
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
code = db.Column(db.String(255))
|
||||
|
||||
@property
|
||||
|
|
@ -1393,7 +1393,7 @@ class ApiToken(db.Model):
|
|||
type = db.Column(db.String(16), nullable=False)
|
||||
token = db.Column(db.String(255), nullable=False)
|
||||
last_used_at = db.Column(db.DateTime, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@staticmethod
|
||||
def generate_api_key(prefix, n):
|
||||
|
|
@ -1424,9 +1424,7 @@ class UploadFile(db.Model):
|
|||
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
|
||||
)
|
||||
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = db.Column(
|
||||
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
|
||||
)
|
||||
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True)
|
||||
used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
|
||||
|
|
@ -1483,7 +1481,7 @@ class ApiRequest(db.Model):
|
|||
request = db.Column(db.Text, nullable=True)
|
||||
response = db.Column(db.Text, nullable=True)
|
||||
ip = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class MessageChain(db.Model):
|
||||
|
|
@ -1655,7 +1653,7 @@ class Tag(db.Model):
|
|||
type = db.Column(db.String(16), nullable=False)
|
||||
name = db.Column(db.String(255), nullable=False)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TagBinding(db.Model):
|
||||
|
|
@ -1671,7 +1669,7 @@ class TagBinding(db.Model):
|
|||
tag_id = db.Column(StringUUID, nullable=True)
|
||||
target_id = db.Column(StringUUID, nullable=True)
|
||||
created_by = db.Column(StringUUID, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TraceAppConfig(db.Model):
|
||||
|
|
@ -1685,8 +1683,10 @@ class TraceAppConfig(db.Model):
|
|||
app_id = db.Column(StringUUID, nullable=False)
|
||||
tracing_provider = db.Column(db.String(255), nullable=True)
|
||||
tracing_config = db.Column(db.JSON, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.now())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now())
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from enum import Enum
|
||||
|
||||
from sqlalchemy import func
|
||||
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
|
@ -60,8 +62,8 @@ class Provider(db.Model):
|
|||
quota_limit = db.Column(db.BigInteger, nullable=True)
|
||||
quota_used = db.Column(db.BigInteger, default=0)
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
|
|
@ -108,8 +110,8 @@ class ProviderModel(db.Model):
|
|||
model_type = db.Column(db.String(40), nullable=False)
|
||||
encrypted_config = db.Column(db.Text, nullable=True)
|
||||
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantDefaultModel(db.Model):
|
||||
|
|
@ -124,8 +126,8 @@ class TenantDefaultModel(db.Model):
|
|||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
model_name = db.Column(db.String(255), nullable=False)
|
||||
model_type = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class TenantPreferredModelProvider(db.Model):
|
||||
|
|
@ -139,8 +141,8 @@ class TenantPreferredModelProvider(db.Model):
|
|||
tenant_id = db.Column(StringUUID, nullable=False)
|
||||
provider_name = db.Column(db.String(255), nullable=False)
|
||||
preferred_provider_type = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderOrder(db.Model):
|
||||
|
|
@ -164,8 +166,8 @@ class ProviderOrder(db.Model):
|
|||
paid_at = db.Column(db.DateTime)
|
||||
pay_failed_at = db.Column(db.DateTime)
|
||||
refunded_at = db.Column(db.DateTime)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ProviderModelSetting(db.Model):
|
||||
|
|
@ -186,8 +188,8 @@ class ProviderModelSetting(db.Model):
|
|||
model_type = db.Column(db.String(40), nullable=False)
|
||||
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class LoadBalancingModelConfig(db.Model):
|
||||
|
|
@ -209,5 +211,5 @@ class LoadBalancingModelConfig(db.Model):
|
|||
name = db.Column(db.String(255), nullable=False)
|
||||
encrypted_config = db.Column(db.Text, nullable=True)
|
||||
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import json
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
from .engine import db
|
||||
|
|
@ -19,8 +20,8 @@ class DataSourceOauthBinding(db.Model):
|
|||
access_token = db.Column(db.String(255), nullable=False)
|
||||
provider = db.Column(db.String(255), nullable=False)
|
||||
source_info = db.Column(JSONB, nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
|
||||
|
||||
|
|
@ -37,8 +38,8 @@ class DataSourceApiKeyAuthBinding(db.Model):
|
|||
category = db.Column(db.String(255), nullable=False)
|
||||
provider = db.Column(db.String(255), nullable=False)
|
||||
credentials = db.Column(db.Text, nullable=True) # JSON
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
|
||||
|
||||
def to_dict(self):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import json
|
|||
from typing import Optional
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import ForeignKey, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
|
@ -36,8 +36,8 @@ class BuiltinToolProvider(db.Model):
|
|||
provider = db.Column(db.String(40), nullable=False)
|
||||
# credential of the tool provider
|
||||
encrypted_credentials = db.Column(db.Text, nullable=True)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def credentials(self) -> dict:
|
||||
|
|
@ -74,8 +74,8 @@ class PublishedAppTool(db.Model):
|
|||
tool_name = db.Column(db.String(40), nullable=False)
|
||||
# author
|
||||
author = db.Column(db.String(40), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def description_i18n(self) -> I18nObject:
|
||||
|
|
@ -120,8 +120,8 @@ class ApiToolProvider(db.Model):
|
|||
# custom_disclaimer
|
||||
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def schema_type(self) -> ApiProviderSchemaType:
|
||||
|
|
@ -198,8 +198,8 @@ class WorkflowToolProvider(db.Model):
|
|||
# privacy policy
|
||||
privacy_policy = db.Column(db.String(255), nullable=True, server_default="")
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def user(self) -> Account | None:
|
||||
|
|
@ -251,8 +251,8 @@ class ToolModelInvoke(db.Model):
|
|||
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
|
||||
total_price = db.Column(db.Numeric(10, 7))
|
||||
currency = db.Column(db.String(255), nullable=False)
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
|
||||
class ToolConversationVariables(db.Model):
|
||||
|
|
@ -278,8 +278,8 @@ class ToolConversationVariables(db.Model):
|
|||
# variables pool
|
||||
variables_str = db.Column(db.Text, nullable=False)
|
||||
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
|
||||
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
@property
|
||||
def variables(self) -> dict:
|
||||
|
|
@@ -1,3 +1,6 @@
+from sqlalchemy import func
+from sqlalchemy.orm import Mapped, mapped_column
+
 from .engine import db
 from .model import Message
 from .types import StringUUID
@@ -15,7 +18,7 @@ class SavedMessage(db.Model):
     message_id = db.Column(StringUUID, nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
     created_by = db.Column(StringUUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

     @property
     def message(self):
@@ -31,7 +34,7 @@ class PinnedConversation(db.Model):

     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     app_id = db.Column(StringUUID, nullable=False)
-    conversation_id = db.Column(StringUUID, nullable=False)
+    conversation_id: Mapped[str] = mapped_column(StringUUID)
     created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
     created_by = db.Column(StringUUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -103,12 +103,13 @@ class Workflow(db.Model):
     graph: Mapped[str] = mapped_column(sa.Text)
     _features: Mapped[str] = mapped_column("features", sa.TEXT)
     created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    created_at: Mapped[datetime] = mapped_column(
-        db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
-    )
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
     updated_at: Mapped[datetime] = mapped_column(
-        sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp()
+        db.DateTime,
+        nullable=False,
+        default=datetime.now(UTC).replace(tzinfo=None),
+        server_onupdate=func.current_timestamp(),
     )
     _environment_variables: Mapped[str] = mapped_column(
         "environment_variables", db.Text, nullable=False, server_default="{}"
@@ -406,7 +407,7 @@ class WorkflowRun(db.Model):
     total_steps = db.Column(db.Integer, server_default=db.text("0"))
     created_by_role = db.Column(db.String(255), nullable=False)  # account, end_user
     created_by = db.Column(StringUUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     finished_at = db.Column(db.DateTime)
     exceptions_count = db.Column(db.Integer, server_default=db.text("0"))

@@ -636,7 +637,7 @@ class WorkflowNodeExecution(db.Model):
     error = db.Column(db.Text)
     elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
     execution_metadata = db.Column(db.Text)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
     finished_at = db.Column(db.DateTime)
@@ -754,7 +755,7 @@ class WorkflowAppLog(db.Model):
     created_from = db.Column(db.String(255), nullable=False)
     created_by_role = db.Column(db.String(255), nullable=False)
     created_by = db.Column(StringUUID, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

     @property
     def workflow_run(self):
@@ -780,7 +781,7 @@ class ConversationVariable(db.Model):
     conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
     app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
     data = db.Column(db.Text, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=func.current_timestamp())
     updated_at = db.Column(
         db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
     )
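Note on the `Workflow.updated_at` change above: the Python-side default moves from a timezone-aware `datetime.now(tz=UTC)` to a naive UTC value, which matches a timezone-naive `DateTime` column. A minimal, self-contained sketch of the difference (not tied to the models above):

```python
from datetime import UTC, datetime

# Timezone-aware value; writing it into a naive DateTime column can shift
# or reject the value depending on the driver and dialect.
aware = datetime.now(tz=UTC)

# Naive UTC value, as used for the new default: same instant, no tzinfo.
naive_utc = datetime.now(UTC).replace(tzinfo=None)

assert naive_utc.tzinfo is None
```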
@@ -340,7 +340,10 @@ class AppDslService:
     ) -> App:
         """Create a new app or update an existing one."""
         app_data = data.get("app", {})
-        app_mode = AppMode(app_data["mode"])
+        app_mode = app_data.get("mode")
+        if not app_mode:
+            raise ValueError("loss app mode")
+        app_mode = AppMode(app_mode)

         # Set icon type
         icon_type_value = icon_type or app_data.get("icon_type")
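The hunk above makes the DSL import reject payloads that omit the app mode before coercing the value into the `AppMode` enum, so a malformed file fails with an explicit error instead of a `KeyError`. A small, self-contained sketch of the same guard-then-coerce pattern (the enum values here are illustrative, not the project's full set):

```python
from enum import Enum


class AppMode(Enum):
    CHAT = "chat"
    WORKFLOW = "workflow"


def parse_app_mode(app_data: dict) -> AppMode:
    # Missing key -> explicit error instead of KeyError.
    mode = app_data.get("mode")
    if not mode:
        raise ValueError("app mode is missing from the DSL data")
    # Unknown value still raises ValueError from the Enum constructor.
    return AppMode(mode)
```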
@@ -1,8 +1,9 @@
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
 from datetime import UTC, datetime
 from typing import Optional, Union

-from sqlalchemy import asc, desc, or_
+from sqlalchemy import asc, desc, func, or_, select
+from sqlalchemy.orm import Session

 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.llm_generator.llm_generator import LLMGenerator
@@ -18,19 +19,21 @@ class ConversationService:
     @classmethod
     def pagination_by_last_id(
         cls,
+        *,
+        session: Session,
         app_model: App,
         user: Optional[Union[Account, EndUser]],
         last_id: Optional[str],
         limit: int,
         invoke_from: InvokeFrom,
-        include_ids: Optional[list] = None,
-        exclude_ids: Optional[list] = None,
+        include_ids: Optional[Sequence[str]] = None,
+        exclude_ids: Optional[Sequence[str]] = None,
         sort_by: str = "-updated_at",
     ) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

-        base_query = db.session.query(Conversation).filter(
+        stmt = select(Conversation).where(
             Conversation.is_deleted == False,
             Conversation.app_id == app_model.id,
             Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
@@ -38,37 +41,40 @@ class ConversationService:
             Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
             or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
         )

         if include_ids is not None:
-            base_query = base_query.filter(Conversation.id.in_(include_ids))
-
+            stmt = stmt.where(Conversation.id.in_(include_ids))
         if exclude_ids is not None:
-            base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
+            stmt = stmt.where(~Conversation.id.in_(exclude_ids))

         # define sort fields and directions
         sort_field, sort_direction = cls._get_sort_params(sort_by)

         if last_id:
-            last_conversation = base_query.filter(Conversation.id == last_id).first()
+            last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
             if not last_conversation:
                 raise LastConversationNotExistsError()

             # build filters based on sorting
-            filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
-            base_query = base_query.filter(filter_condition)
-
-        base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
-
-        conversations = base_query.limit(limit).all()
+            filter_condition = cls._build_filter_condition(
+                sort_field=sort_field,
+                sort_direction=sort_direction,
+                reference_conversation=last_conversation,
+            )
+            stmt = stmt.where(filter_condition)
+        query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
+        conversations = session.scalars(query_stmt).all()

         has_more = False
         if len(conversations) == limit:
             current_page_last_conversation = conversations[-1]
             rest_filter_condition = cls._build_filter_condition(
-                sort_field, sort_direction, current_page_last_conversation, is_next_page=True
+                sort_field=sort_field,
+                sort_direction=sort_direction,
+                reference_conversation=current_page_last_conversation,
             )
-            rest_count = base_query.filter(rest_filter_condition).count()
-
+            count_stmt = stmt.where(rest_filter_condition)
+            count_stmt = select(func.count()).select_from(count_stmt.subquery())
+            rest_count = session.scalar(count_stmt) or 0
             if rest_count > 0:
                 has_more = True

@@ -81,11 +87,9 @@ class ConversationService:
         return sort_by, asc

     @classmethod
-    def _build_filter_condition(
-        cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
-    ):
+    def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
         field_value = getattr(reference_conversation, sort_field)
-        if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
+        if sort_direction == desc:
             return getattr(Conversation, sort_field) < field_value
         else:
             return getattr(Conversation, sort_field) > field_value
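The refactor above replaces legacy `Query` objects with SQLAlchemy 2.0-style `select()` statements executed through an injected `Session`, while keeping the keyset ("last id") pagination approach: fetch one page ordered by the sort field, then check whether any rows lie beyond the last one returned. A condensed, self-contained sketch of that pattern, using an assumed table rather than the project's real model:

```python
from sqlalchemy import Column, DateTime, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Conversation(Base):  # placeholder model for illustration only
    __tablename__ = "conversations"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))
    updated_at = Column(DateTime, server_default=func.current_timestamp())


def page_after(session: Session, last_updated_at=None, limit: int = 20):
    stmt = select(Conversation)
    if last_updated_at is not None:
        # Keyset condition: strictly older than the last row already shown.
        stmt = stmt.where(Conversation.updated_at < last_updated_at)
    page_stmt = stmt.order_by(Conversation.updated_at.desc()).limit(limit)
    rows = session.scalars(page_stmt).all()

    has_more = False
    if len(rows) == limit:
        # Count only what lies beyond the last row of this page.
        rest_stmt = select(func.count()).select_from(
            stmt.where(Conversation.updated_at < rows[-1].updated_at).subquery()
        )
        has_more = (session.scalar(rest_stmt) or 0) > 0
    return rows, has_more


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add_all([Conversation(name="a"), Conversation(name="b")])
    session.commit()
    rows, has_more = page_after(session)
```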
@@ -151,7 +151,12 @@ class MessageService:

     @classmethod
     def create_feedback(
-        cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
+        cls,
+        app_model: App,
+        message_id: str,
+        user: Optional[Union[Account, EndUser]],
+        rating: Optional[str],
+        content: Optional[str],
     ) -> MessageFeedback:
         if not user:
             raise ValueError("user cannot be None")
@@ -164,6 +169,7 @@ class MessageService:
             db.session.delete(feedback)
         elif rating and feedback:
             feedback.rating = rating
+            feedback.content = content
         elif not rating and not feedback:
             raise ValueError("rating cannot be None when feedback not exists")
         else:
@@ -172,6 +178,7 @@ class MessageService:
             conversation_id=message.conversation_id,
             message_id=message.id,
             rating=rating,
+            content=content,
             from_source=("user" if isinstance(user, EndUser) else "admin"),
             from_end_user_id=(user.id if isinstance(user, EndUser) else None),
             from_account_id=(user.id if isinstance(user, Account) else None),
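With the extra parameter, callers pass the optional feedback text alongside the rating. A hypothetical call site (the app object, message id, and user are placeholders assumed to come from the request context):

```python
# Hypothetical usage; app_model, message_id, and current_user are placeholders.
MessageService.create_feedback(
    app_model,
    message_id,
    current_user,
    "like",                         # rating; passing None removes existing feedback
    "Helpful and accurate answer",  # content: optional free-text feedback
)
```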
@@ -2,6 +2,9 @@ import json
 import logging
 from pathlib import Path

+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from configs import dify_config
 from core.helper.position_helper import is_filtered
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -32,7 +35,7 @@ class BuiltinToolManageService:
                 tenant_id=tenant_id, provider_controller=provider_controller
             )
             # check if user has added the provider
-            builtin_provider: BuiltinToolProvider = (
+            builtin_provider = (
                 db.session.query(BuiltinToolProvider)
                 .filter(
                     BuiltinToolProvider.tenant_id == tenant_id,
@@ -71,19 +74,18 @@ class BuiltinToolManageService:
         return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])

     @staticmethod
-    def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
+    def update_builtin_tool_provider(
+        session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
+    ):
         """
         update builtin tool provider
         """
         # get if the provider exists
-        provider: BuiltinToolProvider = (
-            db.session.query(BuiltinToolProvider)
-            .filter(
-                BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider_name,
-            )
-            .first()
+        stmt = select(BuiltinToolProvider).where(
+            BuiltinToolProvider.tenant_id == tenant_id,
+            BuiltinToolProvider.provider == provider_name,
         )
+        provider = session.scalar(stmt)

         try:
             # get provider
@@ -115,13 +117,10 @@ class BuiltinToolManageService:
                 encrypted_credentials=json.dumps(credentials),
             )

-            db.session.add(provider)
-            db.session.commit()
+            session.add(provider)

         else:
             provider.encrypted_credentials = json.dumps(credentials)
-            db.session.add(provider)
-            db.session.commit()

         # delete cache
         tool_configuration.delete_tool_credentials_cache()
@@ -129,15 +128,15 @@ class BuiltinToolManageService:
         return {"result": "success"}

     @staticmethod
-    def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
+    def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
         """
         get builtin tool provider credentials
         """
-        provider: BuiltinToolProvider = (
+        provider = (
             db.session.query(BuiltinToolProvider)
             .filter(
                 BuiltinToolProvider.tenant_id == tenant_id,
-                BuiltinToolProvider.provider == provider,
+                BuiltinToolProvider.provider == provider_name,
             )
             .first()
         )
@@ -156,7 +155,7 @@ class BuiltinToolManageService:
         """
         delete tool provider
         """
-        provider: BuiltinToolProvider = (
+        provider = (
             db.session.query(BuiltinToolProvider)
             .filter(
                 BuiltinToolProvider.tenant_id == tenant_id,
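The updated service takes a `Session` from the caller and only stages changes on it (`session.add` for new rows, plain attribute mutation for loaded rows); the caller decides when the unit of work commits. A minimal, self-contained sketch of that pattern with a placeholder model, not the project's real one:

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class ToolProvider(Base):  # placeholder model for illustration only
    __tablename__ = "tool_providers"
    id = Column(Integer, primary_key=True)
    tenant_id = Column(String(64))
    name = Column(String(64))
    credentials = Column(String(255))


def upsert_credentials(session: Session, tenant_id: str, name: str, credentials: str) -> None:
    stmt = select(ToolProvider).where(ToolProvider.tenant_id == tenant_id, ToolProvider.name == name)
    provider = session.scalar(stmt)
    if provider is None:
        # New row: stage it in the unit of work; no commit here.
        session.add(ToolProvider(tenant_id=tenant_id, name=name, credentials=credentials))
    else:
        # Existing row is already tracked by the session; mutating it is enough.
        provider.credentials = credentials


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as session, session.begin():
    # The caller owns the transaction; it commits when this block exits.
    upsert_credentials(session, "tenant-1", "provider-x", "{}")
```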
@@ -1,5 +1,8 @@
 from typing import Optional, Union

+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.app.entities.app_invoke_entities import InvokeFrom
 from extensions.ext_database import db
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
@@ -13,6 +16,8 @@ class WebConversationService:
     @classmethod
     def pagination_by_last_id(
         cls,
+        *,
+        session: Session,
         app_model: App,
         user: Optional[Union[Account, EndUser]],
         last_id: Optional[str],
@@ -23,24 +28,25 @@ class WebConversationService:
     ) -> InfiniteScrollPagination:
         include_ids = None
         exclude_ids = None
-        if pinned is not None:
-            pinned_conversations = (
-                db.session.query(PinnedConversation)
-                .filter(
+        if pinned is not None and user:
+            stmt = (
+                select(PinnedConversation.conversation_id)
+                .where(
                     PinnedConversation.app_id == app_model.id,
                     PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
                     PinnedConversation.created_by == user.id,
                 )
                 .order_by(PinnedConversation.created_at.desc())
-                .all()
             )
-            pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
+            pinned_conversation_ids = session.scalars(stmt).all()

             if pinned:
                 include_ids = pinned_conversation_ids
             else:
                 exclude_ids = pinned_conversation_ids

         return ConversationService.pagination_by_last_id(
+            session=session,
             app_model=app_model,
             user=user,
             last_id=last_id,
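Because the new statement selects only `PinnedConversation.conversation_id`, `session.scalars()` already yields the id values and the per-row attribute access disappears. A small, self-contained illustration of `scalar()` versus `scalars()` (the table and values below are made up):

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, insert, select
from sqlalchemy.orm import Session

metadata = MetaData()
pinned = Table(
    "pinned_conversations",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("conversation_id", String(36)),
)

engine = create_engine("sqlite:///:memory:")
metadata.create_all(engine)

with Session(engine) as session:
    session.execute(insert(pinned), [{"conversation_id": "c-1"}, {"conversation_id": "c-2"}])

    # scalars(): one value per row when a single column is selected.
    ids = session.scalars(select(pinned.c.conversation_id).order_by(pinned.c.id)).all()  # ['c-1', 'c-2']

    # scalar(): first column of the first row, or None if there are no rows.
    first = session.scalar(select(pinned.c.conversation_id).order_by(pinned.c.id))  # 'c-1'
```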
@@ -923,6 +923,3 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
 # Maximum number of submitted thread count in a ThreadPool for parallel node execution
 MAX_SUBMIT_COUNT=100

-# Proxy
-HTTP_PROXY=
-HTTPS_PROXY=
@@ -386,8 +386,6 @@ x-shared-env: &shared-api-worker-env
   CSP_WHITELIST: ${CSP_WHITELIST:-}
   CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
   MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100}
-  HTTP_PROXY: ${HTTP_PROXY:-}
-  HTTPS_PROXY: ${HTTPS_PROXY:-}

 services:
   # API service
|
|||
|
|
@ -64,6 +64,12 @@ const WorkflowProcessItem = ({
|
|||
setShowMessageLogModal(true)
|
||||
}, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
|
||||
|
||||
const showRetryDetail = useCallback(() => {
|
||||
setCurrentLogItem(item)
|
||||
setCurrentLogModalActiveTab('TRACING')
|
||||
setShowMessageLogModal(true)
|
||||
}, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
|
|
@ -105,6 +111,7 @@ const WorkflowProcessItem = ({
|
|||
<TracingPanel
|
||||
list={data.tracing}
|
||||
onShowIterationDetail={showIterationDetail}
|
||||
onShowRetryDetail={showRetryDetail}
|
||||
hideNodeInfo={hideInfo}
|
||||
hideNodeProcessDetail={hideProcessDetail}
|
||||
/>
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ export type InputProps = {
|
|||
destructive?: boolean
|
||||
wrapperClassName?: string
|
||||
styleCss?: CSSProperties
|
||||
unit?: string
|
||||
} & React.InputHTMLAttributes<HTMLInputElement> & VariantProps<typeof inputVariants>
|
||||
|
||||
const Input = ({
|
||||
|
|
@ -43,6 +44,7 @@ const Input = ({
|
|||
value,
|
||||
placeholder,
|
||||
onChange,
|
||||
unit,
|
||||
...props
|
||||
}: InputProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
|
@ -80,6 +82,13 @@ const Input = ({
|
|||
{destructive && (
|
||||
<RiErrorWarningLine className='absolute right-2 top-1/2 -translate-y-1/2 w-4 h-4 text-text-destructive-secondary' />
|
||||
)}
|
||||
{
|
||||
unit && (
|
||||
<div className='absolute right-2 top-1/2 -translate-y-1/2 system-sm-regular text-text-tertiary'>
|
||||
{unit}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ const SearchInput: FC<SearchInputProps> = ({
|
|||
const { t } = useTranslation()
|
||||
const [focus, setFocus] = useState<boolean>(false)
|
||||
const isComposing = useRef<boolean>(false)
|
||||
const [internalValue, setInternalValue] = useState<string>(value)
|
||||
|
||||
return (
|
||||
<div className={cn(
|
||||
|
|
@ -45,16 +46,18 @@ const SearchInput: FC<SearchInputProps> = ({
|
|||
white && '!bg-white hover:!bg-white group-hover:!bg-white placeholder:!text-gray-400',
|
||||
)}
|
||||
placeholder={placeholder || t('common.operation.search')!}
|
||||
value={value}
|
||||
value={internalValue}
|
||||
onChange={(e) => {
|
||||
setInternalValue(e.target.value)
|
||||
if (!isComposing.current)
|
||||
onChange(e.target.value)
|
||||
}}
|
||||
onCompositionStart={() => {
|
||||
isComposing.current = true
|
||||
}}
|
||||
onCompositionEnd={() => {
|
||||
onCompositionEnd={(e) => {
|
||||
isComposing.current = false
|
||||
onChange(e.data)
|
||||
}}
|
||||
onFocus={() => setFocus(true)}
|
||||
onBlur={() => setFocus(false)}
|
||||
|
|
@ -63,7 +66,10 @@ const SearchInput: FC<SearchInputProps> = ({
|
|||
{value && (
|
||||
<div
|
||||
className='shrink-0 flex items-center justify-center w-4 h-4 cursor-pointer group/clear'
|
||||
onClick={() => onChange('')}
|
||||
onClick={() => {
|
||||
onChange('')
|
||||
setInternalValue('')
|
||||
}}
|
||||
>
|
||||
<XCircle className='w-3.5 h-3.5 text-gray-400 group-hover/clear:text-gray-600' />
|
||||
</div>
|
||||
|
|
@@ -346,6 +346,9 @@ The text generation application offers non-session support and is ideal for tran
     <Property name='user' type='string' key='user'>
       User identifier, defined by the developer's rules, must be unique within the application.
     </Property>
+    <Property name='content' type='string' key='content'>
+      The specific content of message feedback.
+    </Property>
   </Properties>

   ### Response
@@ -353,7 +356,7 @@ The text generation application offers non-session support and is ideal for tran
   </Col>
   <Col sticky>

-  <CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
+  <CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>

    ```bash {{ title: 'cURL' }}
    curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@@ -361,7 +364,8 @@ The text generation application offers non-session support and is ideal for tran
    --header 'Content-Type: application/json' \
    --data-raw '{
        "rating": "like",
-       "user": "abc-123"
+       "user": "abc-123",
+       "content": "message feedback information"
    }'
    ```
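The same feedback request, including the new `content` field, can be sent from any HTTP client. A minimal Python sketch using the `requests` library, with placeholder values for the base URL, API key, and message id:

```python
import requests

API_BASE_URL = "https://api.example.com/v1"  # placeholder base URL
API_KEY = "app-xxxxxxxx"                     # placeholder API key
MESSAGE_ID = "your-message-id"               # placeholder message id

resp = requests.post(
    f"{API_BASE_URL}/messages/{MESSAGE_ID}/feedbacks",
    headers={
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    },
    json={
        "rating": "like",
        "user": "abc-123",
        "content": "message feedback information",
    },
)
resp.raise_for_status()
print(resp.json())
```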
|||
|
|
@ -345,6 +345,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
<Property name='user' type='string' key='user'>
|
||||
開発者のルールで定義されたユーザー識別子。アプリケーション内で一意である必要があります。
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
メッセージのフィードバックです。
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### レスポンス
|
||||
|
|
@ -352,7 +355,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -360,7 +363,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -320,6 +320,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
<Property name='user' type='string' key='user'>
|
||||
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
消息反馈的具体信息。
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Response
|
||||
|
|
@ -327,7 +330,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -335,7 +338,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -444,6 +444,9 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
<Property name='user' type='string' key='user'>
|
||||
User identifier, defined by the developer's rules, must be unique within the application.
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
The specific content of message feedback.
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Response
|
||||
|
|
@ -451,7 +454,7 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -459,7 +462,8 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -444,6 +444,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
<Property name='user' type='string' key='user'>
|
||||
ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
メッセージのフィードバックです。
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### 応答
|
||||
|
|
@ -451,7 +454,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -459,7 +462,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -450,6 +450,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
<Property name='user' type='string' key='user'>
|
||||
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
消息反馈的具体信息。
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Response
|
||||
|
|
@ -457,7 +460,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -465,7 +468,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -408,6 +408,9 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
<Property name='user' type='string' key='user'>
|
||||
User identifier, defined by the developer's rules, must be unique within the application.
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
The specific content of message feedback.
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Response
|
||||
|
|
@ -415,7 +418,7 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -423,7 +426,8 @@ Chat applications support session persistence, allowing previous chat history to
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -408,6 +408,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
<Property name='user' type='string' key='user'>
|
||||
ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
メッセージのフィードバックです。
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### 応答
|
||||
|
|
@ -415,7 +418,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -423,7 +426,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -423,6 +423,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
<Property name='user' type='string' key='user'>
|
||||
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
|
||||
</Property>
|
||||
<Property name='content' type='string' key='content'>
|
||||
消息反馈的具体信息。
|
||||
</Property>
|
||||
</Properties>
|
||||
|
||||
### Response
|
||||
|
|
@ -430,7 +433,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
</Col>
|
||||
<Col sticky>
|
||||
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
|
||||
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
|
||||
|
||||
```bash {{ title: 'cURL' }}
|
||||
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
|
||||
|
|
@ -438,7 +441,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
|
|||
--header 'Content-Type: application/json' \
|
||||
--data-raw '{
|
||||
"rating": "like",
|
||||
"user": "abc-123"
|
||||
"user": "abc-123",
|
||||
"content": "message feedback information"
|
||||
}'
|
||||
```
|
||||
|
||||
|
|
@@ -506,3 +506,5 @@ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE'
 export const CUSTOM_NODE = 'custom'
 export const CUSTOM_EDGE = 'custom'
 export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK'
+export const DEFAULT_RETRY_MAX = 3
+export const DEFAULT_RETRY_INTERVAL = 100
|
|||
|
|
@ -28,6 +28,7 @@ import {
|
|||
getFilesInLogs,
|
||||
} from '@/app/components/base/file-uploader/utils'
|
||||
import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types'
|
||||
import type { NodeTracing } from '@/types/workflow'
|
||||
|
||||
export const useWorkflowRun = () => {
|
||||
const store = useStoreApi()
|
||||
|
|
@ -114,6 +115,7 @@ export const useWorkflowRun = () => {
|
|||
onIterationStart,
|
||||
onIterationNext,
|
||||
onIterationFinish,
|
||||
onNodeRetry,
|
||||
onError,
|
||||
...restCallback
|
||||
} = callback || {}
|
||||
|
|
@ -440,10 +442,13 @@ export const useWorkflowRun = () => {
|
|||
})
|
||||
if (currentIndex > -1 && draft.tracing) {
|
||||
draft.tracing[currentIndex] = {
|
||||
...data,
|
||||
...(draft.tracing[currentIndex].extras
|
||||
? { extras: draft.tracing[currentIndex].extras }
|
||||
: {}),
|
||||
...data,
|
||||
...(draft.tracing[currentIndex].retryDetail
|
||||
? { retryDetail: draft.tracing[currentIndex].retryDetail }
|
||||
: {}),
|
||||
} as any
|
||||
}
|
||||
}))
|
||||
|
|
@ -616,6 +621,41 @@ export const useWorkflowRun = () => {
|
|||
if (onIterationFinish)
|
||||
onIterationFinish(params)
|
||||
},
|
||||
onNodeRetry: (params) => {
|
||||
const { data } = params
|
||||
const {
|
||||
workflowRunningData,
|
||||
setWorkflowRunningData,
|
||||
} = workflowStore.getState()
|
||||
const {
|
||||
getNodes,
|
||||
setNodes,
|
||||
} = store.getState()
|
||||
|
||||
const nodes = getNodes()
|
||||
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
|
||||
const tracing = draft.tracing!
|
||||
const currentRetryNodeIndex = tracing.findIndex(trace => trace.node_id === data.node_id)
|
||||
|
||||
if (currentRetryNodeIndex > -1) {
|
||||
const currentRetryNode = tracing[currentRetryNodeIndex]
|
||||
if (currentRetryNode.retryDetail)
|
||||
draft.tracing![currentRetryNodeIndex].retryDetail!.push(data as NodeTracing)
|
||||
|
||||
else
|
||||
draft.tracing![currentRetryNodeIndex].retryDetail = [data as NodeTracing]
|
||||
}
|
||||
}))
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
const currentNode = draft.find(node => node.id === data.node_id)!
|
||||
|
||||
currentNode.data._retryIndex = data.retry_index
|
||||
})
|
||||
setNodes(newNodes)
|
||||
|
||||
if (onNodeRetry)
|
||||
onNodeRetry(params)
|
||||
},
|
||||
onParallelBranchStarted: (params) => {
|
||||
// console.log(params, 'parallel start')
|
||||
},
|
||||
|
|
|
|||
|
|
@@ -17,17 +17,25 @@ import ResultPanel from '@/app/components/workflow/run/result-panel'
import Toast from '@/app/components/base/toast'
import { TransferMethod } from '@/types/app'
import { getProcessedFiles } from '@/app/components/base/file-uploader/utils'
import type { NodeTracing } from '@/types/workflow'
import RetryResultPanel from '@/app/components/workflow/run/retry-result-panel'
import type { BlockEnum } from '@/app/components/workflow/types'
import type { Emoji } from '@/app/components/tools/types'

const i18nPrefix = 'workflow.singleRun'

type BeforeRunFormProps = {
nodeName: string
nodeType?: BlockEnum
toolIcon?: string | Emoji
onHide: () => void
onRun: (submitData: Record<string, any>) => void
onStop: () => void
runningStatus: NodeRunningStatus
result?: JSX.Element
forms: FormProps[]
retryDetails?: NodeTracing[]
onRetryDetailBack?: any
}

function formatValue(value: string | any, type: InputVarType) {
@@ -50,12 +58,16 @@ function formatValue(value: string | any, type: InputVarType) {
}
const BeforeRunForm: FC<BeforeRunFormProps> = ({
nodeName,
nodeType,
toolIcon,
onHide,
onRun,
onStop,
runningStatus,
result,
forms,
retryDetails,
onRetryDetailBack = () => { },
}) => {
const { t } = useTranslation()

@@ -122,48 +134,69 @@ const BeforeRunForm: FC<BeforeRunFormProps> = ({
<div className='text-base font-semibold text-gray-900 truncate'>
{t(`${i18nPrefix}.testRun`)} {nodeName}
</div>
<div className='ml-2 shrink-0 p-1 cursor-pointer' onClick={onHide}>
<div className='ml-2 shrink-0 p-1 cursor-pointer' onClick={() => {
onHide()
}}>
<RiCloseLine className='w-4 h-4 text-gray-500 ' />
</div>
</div>

<div className='h-0 grow overflow-y-auto pb-4'>
<div className='mt-3 px-4 space-y-4'>
{forms.map((form, index) => (
<div key={index}>
<Form
key={index}
className={cn(index < forms.length - 1 && 'mb-4')}
{...form}
/>
{index < forms.length - 1 && <Split />}
{
retryDetails?.length && (
<div className='h-0 grow overflow-y-auto pb-4'>
<RetryResultPanel
list={retryDetails.map((item, index) => ({
...item,
title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`,
node_type: nodeType!,
extras: {
icon: toolIcon!,
},
}))}
onBack={onRetryDetailBack}
/>
</div>
)
}
{
!retryDetails?.length && (
<div className='h-0 grow overflow-y-auto pb-4'>
<div className='mt-3 px-4 space-y-4'>
{forms.map((form, index) => (
<div key={index}>
<Form
key={index}
className={cn(index < forms.length - 1 && 'mb-4')}
{...form}
/>
{index < forms.length - 1 && <Split />}
</div>
))}
</div>
))}
</div>

<div className='mt-4 flex justify-between space-x-2 px-4' >
{isRunning && (
<div
className='p-2 rounded-lg border border-gray-200 bg-white shadow-xs cursor-pointer'
onClick={onStop}
>
<StopCircle className='w-4 h-4 text-gray-500' />
<div className='mt-4 flex justify-between space-x-2 px-4' >
{isRunning && (
<div
className='p-2 rounded-lg border border-gray-200 bg-white shadow-xs cursor-pointer'
onClick={onStop}
>
<StopCircle className='w-4 h-4 text-gray-500' />
</div>
)}
<Button disabled={!isFileLoaded || isRunning} variant='primary' className='w-0 grow space-x-2' onClick={handleRun}>
{isRunning && <RiLoader2Line className='animate-spin w-4 h-4 text-white' />}
<div>{t(`${i18nPrefix}.${isRunning ? 'running' : 'startRun'}`)}</div>
</Button>
</div>
)}
<Button disabled={!isFileLoaded || isRunning} variant='primary' className='w-0 grow space-x-2' onClick={handleRun}>
{isRunning && <RiLoader2Line className='animate-spin w-4 h-4 text-white' />}
<div>{t(`${i18nPrefix}.${isRunning ? 'running' : 'startRun'}`)}</div>
</Button>
</div>
{isRunning && (
<ResultPanel status='running' showSteps={false} />
)}
{isFinished && (
<>
{result}
</>
)}
</div>
{isRunning && (
<ResultPanel status='running' showSteps={false} />
)}
{isFinished && (
<>
{result}
</>
)}
</div>
)
}
</div>
</div>
)
@@ -14,7 +14,6 @@ import type {
CommonNodeType,
Node,
} from '@/app/components/workflow/types'
import Split from '@/app/components/workflow/nodes/_base/components/split'
import Tooltip from '@/app/components/base/tooltip'

type ErrorHandleProps = Pick<Node, 'id' | 'data'>
@@ -45,7 +44,6 @@ const ErrorHandle = ({

return (
<>
<Split />
<div className='py-4'>
<Collapse
disabled={!error_strategy}
@@ -0,0 +1,41 @@
import {
useCallback,
useState,
} from 'react'
import type { WorkflowRetryConfig } from './types'
import {
useNodeDataUpdate,
} from '@/app/components/workflow/hooks'
import type { NodeTracing } from '@/types/workflow'

export const useRetryConfig = (
id: string,
) => {
const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate()

const handleRetryConfigChange = useCallback((value?: WorkflowRetryConfig) => {
handleNodeDataUpdateWithSyncDraft({
id,
data: {
retry_config: value,
},
})
}, [id, handleNodeDataUpdateWithSyncDraft])

return {
handleRetryConfigChange,
}
}

export const useRetryDetailShowInSingleRun = () => {
const [retryDetails, setRetryDetails] = useState<NodeTracing[] | undefined>()

const handleRetryDetailsChange = useCallback((details: NodeTracing[] | undefined) => {
setRetryDetails(details)
}, [])

return {
retryDetails,
handleRetryDetailsChange,
}
}
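A minimal usage sketch, not part of this diff, of how a node panel is expected to combine the two hooks above; the wrapper hook name and the wiring comments are illustrative assumptions, while useRetryConfig, useRetryDetailShowInSingleRun and WorkflowRetryConfig come from this change:

// Hypothetical consumer of the hooks above; only the hook APIs are taken from this change.
import { useRetryConfig, useRetryDetailShowInSingleRun } from './hooks'
import type { WorkflowRetryConfig } from './types'

export const useExampleRetryWiring = (nodeId: string) => {
  const { handleRetryConfigChange } = useRetryConfig(nodeId)
  const { retryDetails, handleRetryDetailsChange } = useRetryDetailShowInSingleRun()

  // Persist a retry policy of three attempts, one second apart, into the node data.
  const enableDefaultRetry = () => {
    const config: WorkflowRetryConfig = {
      retry_enabled: true,
      max_retries: 3,
      retry_interval: 1000,
    }
    handleRetryConfigChange(config)
  }

  // retryDetails / handleRetryDetailsChange are intended to be passed to BeforeRunForm
  // as retryDetails / onRetryDetailBack, as the panel hunks later in this diff do.
  return { enableDefaultRetry, retryDetails, handleRetryDetailsChange }
}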
@@ -0,0 +1,88 @@
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiAlertFill,
RiCheckboxCircleFill,
RiLoader2Line,
} from '@remixicon/react'
import type { Node } from '@/app/components/workflow/types'
import { NodeRunningStatus } from '@/app/components/workflow/types'
import cn from '@/utils/classnames'

type RetryOnNodeProps = Pick<Node, 'id' | 'data'>
const RetryOnNode = ({
data,
}: RetryOnNodeProps) => {
const { t } = useTranslation()
const { retry_config } = data
const showSelectedBorder = data.selected || data._isBundled || data._isEntering
const {
isRunning,
isSuccessful,
isException,
isFailed,
} = useMemo(() => {
return {
isRunning: data._runningStatus === NodeRunningStatus.Running && !showSelectedBorder,
isSuccessful: data._runningStatus === NodeRunningStatus.Succeeded && !showSelectedBorder,
isFailed: data._runningStatus === NodeRunningStatus.Failed && !showSelectedBorder,
isException: data._runningStatus === NodeRunningStatus.Exception && !showSelectedBorder,
}
}, [data._runningStatus, showSelectedBorder])
const showDefault = !isRunning && !isSuccessful && !isException && !isFailed

if (!retry_config)
return null

return (
<div className='px-3'>
<div className={cn(
'flex items-center justify-between px-[5px] py-1 bg-workflow-block-parma-bg border-[0.5px] border-transparent rounded-md system-xs-medium-uppercase text-text-tertiary',
isRunning && 'bg-state-accent-hover border-state-accent-active text-text-accent',
isSuccessful && 'bg-state-success-hover border-state-success-active text-text-success',
(isException || isFailed) && 'bg-state-warning-hover border-state-warning-active text-text-warning',
)}>
<div className='flex items-center'>
{
showDefault && (
t('workflow.nodes.common.retry.retryTimes', { times: retry_config.max_retries })
)
}
{
isRunning && (
<>
<RiLoader2Line className='animate-spin mr-1 w-3.5 h-3.5' />
{t('workflow.nodes.common.retry.retrying')}
</>
)
}
{
isSuccessful && (
<>
<RiCheckboxCircleFill className='mr-1 w-3.5 h-3.5' />
{t('workflow.nodes.common.retry.retrySuccessful')}
</>
)
}
{
(isFailed || isException) && (
<>
<RiAlertFill className='mr-1 w-3.5 h-3.5' />
{t('workflow.nodes.common.retry.retryFailed')}
</>
)
}
</div>
{
!showDefault && (
<div>
{data._retryIndex}/{data.retry_config?.max_retries}
</div>
)
}
</div>
</div>
)
}

export default RetryOnNode
@@ -0,0 +1,117 @@
import { useTranslation } from 'react-i18next'
import { useRetryConfig } from './hooks'
import s from './style.module.css'
import Switch from '@/app/components/base/switch'
import Slider from '@/app/components/base/slider'
import Input from '@/app/components/base/input'
import type {
Node,
} from '@/app/components/workflow/types'
import Split from '@/app/components/workflow/nodes/_base/components/split'

type RetryOnPanelProps = Pick<Node, 'id' | 'data'>
const RetryOnPanel = ({
id,
data,
}: RetryOnPanelProps) => {
const { t } = useTranslation()
const { handleRetryConfigChange } = useRetryConfig(id)
const { retry_config } = data

const handleRetryEnabledChange = (value: boolean) => {
handleRetryConfigChange({
retry_enabled: value,
max_retries: retry_config?.max_retries || 3,
retry_interval: retry_config?.retry_interval || 1000,
})
}

const handleMaxRetriesChange = (value: number) => {
if (value > 10)
value = 10
else if (value < 1)
value = 1
handleRetryConfigChange({
retry_enabled: true,
max_retries: value,
retry_interval: retry_config?.retry_interval || 1000,
})
}

const handleRetryIntervalChange = (value: number) => {
if (value > 5000)
value = 5000
else if (value < 100)
value = 100
handleRetryConfigChange({
retry_enabled: true,
max_retries: retry_config?.max_retries || 3,
retry_interval: value,
})
}

return (
<>
<div className='pt-2'>
<div className='flex items-center justify-between px-4 py-2 h-10'>
<div className='flex items-center'>
<div className='mr-0.5 system-sm-semibold-uppercase text-text-secondary'>{t('workflow.nodes.common.retry.retryOnFailure')}</div>
</div>
<Switch
defaultValue={retry_config?.retry_enabled}
onChange={v => handleRetryEnabledChange(v)}
/>
</div>
{
retry_config?.retry_enabled && (
<div className='px-4 pb-2'>
<div className='flex items-center mb-1 w-full'>
<div className='grow mr-2 system-xs-medium-uppercase'>{t('workflow.nodes.common.retry.maxRetries')}</div>
<Slider
className='mr-3 w-[108px]'
value={retry_config?.max_retries || 3}
onChange={handleMaxRetriesChange}
min={1}
max={10}
/>
<Input
type='number'
wrapperClassName='w-[80px]'
value={retry_config?.max_retries || 3}
onChange={e => handleMaxRetriesChange(e.target.value as any)}
min={1}
max={10}
unit={t('workflow.nodes.common.retry.times') || ''}
className={s.input}
/>
</div>
<div className='flex items-center'>
<div className='grow mr-2 system-xs-medium-uppercase'>{t('workflow.nodes.common.retry.retryInterval')}</div>
<Slider
className='mr-3 w-[108px]'
value={retry_config?.retry_interval || 1000}
onChange={handleRetryIntervalChange}
min={100}
max={5000}
/>
<Input
type='number'
wrapperClassName='w-[80px]'
value={retry_config?.retry_interval || 1000}
onChange={e => handleRetryIntervalChange(e.target.value as any)}
min={100}
max={5000}
unit={t('workflow.nodes.common.retry.ms') || ''}
className={s.input}
/>
</div>
</div>
)
}
</div>
<Split className='mx-4 mt-2' />
</>
)
}

export default RetryOnPanel
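The two change handlers above clamp their inputs before persisting them: max_retries to 1–10 and retry_interval to 100–5000 ms. A small sketch of the same rule as a generic helper (hypothetical name; not part of the diff):

// Equivalent clamping, expressed once for illustration.
const clamp = (value: number, min: number, max: number) => Math.min(max, Math.max(min, value))

clamp(42, 1, 10) // 10   -> what handleMaxRetriesChange would persist
clamp(50, 100, 5000) // 100 -> what handleRetryIntervalChange would persist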
@@ -0,0 +1,5 @@
.input::-webkit-inner-spin-button,
.input::-webkit-outer-spin-button {
-webkit-appearance: none;
margin: 0;
}

@@ -0,0 +1,5 @@
export type WorkflowRetryConfig = {
max_retries: number
retry_interval: number
retry_enabled: boolean
}
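For reference, an example value of this type, mirroring the default the HTTP node registers later in this diff (retry enabled, 3 retries, 100 ms interval); the variable name is illustrative only:

// Example only; values mirror the HTTP node default added in this change.
const httpRetryDefault: WorkflowRetryConfig = {
  retry_enabled: true,
  max_retries: 3,
  retry_interval: 100,
}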
@@ -25,7 +25,10 @@ import {
useNodesReadOnly,
useToolIcon,
} from '../../hooks'
import { hasErrorHandleNode } from '../../utils'
import {
hasErrorHandleNode,
hasRetryNode,
} from '../../utils'
import { useNodeIterationInteractions } from '../iteration/use-interactions'
import type { IterationNodeType } from '../iteration/types'
import {
@@ -35,6 +38,7 @@ import {
import NodeResizer from './components/node-resizer'
import NodeControl from './components/node-control'
import ErrorHandleOnNode from './components/error-handle/error-handle-on-node'
import RetryOnNode from './components/retry/retry-on-node'
import AddVariablePopupWithPosition from './components/add-variable-popup-with-position'
import cn from '@/utils/classnames'
import BlockIcon from '@/app/components/workflow/block-icon'
@@ -237,6 +241,14 @@ const BaseNode: FC<BaseNodeProps> = ({
</div>
)
}
{
hasRetryNode(data.type) && (
<RetryOnNode
id={id}
data={data}
/>
)
}
{
hasErrorHandleNode(data.type) && (
<ErrorHandleOnNode
@@ -21,9 +21,11 @@ import {
TitleInput,
} from './components/title-description-input'
import ErrorHandleOnPanel from './components/error-handle/error-handle-on-panel'
import RetryOnPanel from './components/retry/retry-on-panel'
import { useResizePanel } from './hooks/use-resize-panel'
import cn from '@/utils/classnames'
import BlockIcon from '@/app/components/workflow/block-icon'
import Split from '@/app/components/workflow/nodes/_base/components/split'
import {
WorkflowHistoryEvent,
useAvailableBlocks,
@@ -38,6 +40,7 @@ import {
import {
canRunBySingle,
hasErrorHandleNode,
hasRetryNode,
} from '@/app/components/workflow/utils'
import Tooltip from '@/app/components/base/tooltip'
import type { Node } from '@/app/components/workflow/types'
@@ -168,6 +171,15 @@ const BasePanel: FC<BasePanelProps> = ({
<div>
{cloneElement(children, { id, data })}
</div>
<Split />
{
hasRetryNode(data.type) && (
<RetryOnPanel
id={id}
data={data}
/>
)
}
{
hasErrorHandleNode(data.type) && (
<ErrorHandleOnPanel
@@ -2,7 +2,10 @@ import { BlockEnum } from '../../types'
import type { NodeDefault } from '../../types'
import { AuthorizationType, BodyType, Method } from './types'
import type { BodyPayload, HttpNodeType } from './types'
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
import {
ALL_CHAT_AVAILABLE_BLOCKS,
ALL_COMPLETION_AVAILABLE_BLOCKS,
} from '@/app/components/workflow/constants'

const nodeDefault: NodeDefault<HttpNodeType> = {
defaultValue: {
@@ -24,6 +27,11 @@ const nodeDefault: NodeDefault<HttpNodeType> = {
max_read_timeout: 0,
max_write_timeout: 0,
},
retry_config: {
retry_enabled: true,
max_retries: 3,
retry_interval: 100,
},
},
getAvailablePrevNodes(isChatMode: boolean) {
const nodes = isChatMode
@@ -1,5 +1,5 @@
import type { FC } from 'react'
import React from 'react'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import useConfig from './use-config'
import ApiInput from './components/api-input'
@@ -18,6 +18,7 @@ import { FileArrow01 } from '@/app/components/base/icons/src/vender/line/files'
import type { NodePanelProps } from '@/app/components/workflow/types'
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
import ResultPanel from '@/app/components/workflow/run/result-panel'
import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks'

const i18nPrefix = 'workflow.nodes.http'

@@ -60,6 +61,10 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
hideCurlPanel,
handleCurlImport,
} = useConfig(id, data)
const {
retryDetails,
handleRetryDetailsChange,
} = useRetryDetailShowInSingleRun()
// To prevent prompt editor in body not update data.
if (!isDataReady)
return null
@@ -181,6 +186,7 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
nodeType={inputs.type}
onHide={hideSingleRun}
forms={[
{
@@ -192,7 +198,9 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
retryDetails={retryDetails}
onRetryDetailBack={handleRetryDetailsChange}
result={<ResultPanel {...runResult} showSteps={false} onShowRetryDetail={handleRetryDetailsChange} />}
/>
)}
{(isShowCurlPanel && !readOnly) && (
@@ -207,4 +215,4 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
)
}

export default React.memo(Panel)
export default memo(Panel)
@@ -129,9 +129,6 @@ export const getMultipleRetrievalConfig = (
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
}

if (!rerankModelIsValid)
result.reranking_model = undefined

const setDefaultWeights = () => {
result.weights = {
vector_setting: {
@@ -198,7 +195,6 @@ export const getMultipleRetrievalConfig = (
setDefaultWeights()
}
}

if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
result.reranking_mode = RerankingModeEnum.WeightedScore
setDefaultWeights()
@@ -19,6 +19,7 @@ import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/c
import ResultPanel from '@/app/components/workflow/run/result-panel'
import Tooltip from '@/app/components/base/tooltip'
import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor'
import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks'

const i18nPrefix = 'workflow.nodes.llm'

@@ -69,6 +70,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
runResult,
filterJinjia2InputVar,
} = useConfig(id, data)
const {
retryDetails,
handleRetryDetailsChange,
} = useRetryDetailShowInSingleRun()

const model = inputs.model

@@ -282,12 +287,15 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
nodeType={inputs.type}
onHide={hideSingleRun}
forms={singleRunForms}
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
retryDetails={retryDetails}
onRetryDetailBack={handleRetryDetailsChange}
result={<ResultPanel {...runResult} showSteps={false} onShowRetryDetail={handleRetryDetailsChange} />}
/>
)}
</div>
@@ -162,7 +162,7 @@ const InputVarList: FC<Props> = ({
readonly={readOnly}
isShowNodeName
nodeId={nodeId}
value={varInput?.type === VarKindType.constant ? (varInput?.value || '') : (varInput?.value || [])}
value={varInput?.type === VarKindType.constant ? (varInput?.value ?? '') : (varInput?.value ?? [])}
onChange={handleNotMixedTypeChange(variable)}
onOpen={handleOpen(index)}
defaultVarKindType={varInput?.type || (isNumber ? VarKindType.constant : VarKindType.variable)}
@@ -14,6 +14,8 @@ import Loading from '@/app/components/base/loading'
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
import ResultPanel from '@/app/components/workflow/run/result-panel'
import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks'
import { useToolIcon } from '@/app/components/workflow/hooks'

const i18nPrefix = 'workflow.nodes.tool'

@@ -48,6 +50,11 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
handleStop,
runResult,
} = useConfig(id, data)
const toolIcon = useToolIcon(data)
const {
retryDetails,
handleRetryDetailsChange,
} = useRetryDetailShowInSingleRun()

if (isLoading) {
return <div className='flex h-[200px] items-center justify-center'>
@@ -143,12 +150,16 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
nodeType={inputs.type}
toolIcon={toolIcon}
onHide={hideSingleRun}
forms={singleRunForms}
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
retryDetails={retryDetails}
onRetryDetailBack={handleRetryDetailsChange}
result={<ResultPanel {...runResult} showSteps={false} onShowRetryDetail={handleRetryDetailsChange} />}
/>
)}
</div>
@@ -27,6 +27,7 @@ import {
getProcessedFilesFromResponse,
} from '@/app/components/base/file-uploader/utils'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { NodeTracing } from '@/types/workflow'

type GetAbortController = (abortController: AbortController) => void
type SendCallback = {
@@ -381,6 +382,28 @@ export const useChat = (
}
}))
},
onNodeRetry: ({ data }) => {
if (data.iteration_id)
return

const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => {
if (!item.execution_metadata?.parallel_id)
return item.node_id === data.node_id
return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id)
})
if (responseItem.workflowProcess!.tracing[currentIndex].retryDetail)
responseItem.workflowProcess!.tracing[currentIndex].retryDetail?.push(data as NodeTracing)
else
responseItem.workflowProcess!.tracing[currentIndex].retryDetail = [data as NodeTracing]

handleUpdateChatList(produce(chatListRef.current, (draft) => {
const currentIndex = draft.findIndex(item => item.id === responseItem.id)
draft[currentIndex] = {
...draft[currentIndex],
...responseItem,
}
}))
},
onNodeFinished: ({ data }) => {
if (data.iteration_id)
return
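The onNodeRetry handler above first locates the tracing entry the retry event belongs to, falling back to a plain node_id match when no parallel branch is involved. A sketch of that matching rule as a standalone predicate (hypothetical helper name; the field access mirrors the handler):

// Illustrative extraction of the findIndex predicate used in onNodeRetry.
import type { NodeTracing } from '@/types/workflow'

const matchesRetryTarget = (item: NodeTracing, data: NodeTracing): boolean => {
  if (!item.execution_metadata?.parallel_id)
    return item.node_id === data.node_id
  return item.node_id === data.node_id
    && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id
      || item.parallel_id === data.execution_metadata?.parallel_id)
}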
@@ -394,6 +417,9 @@ export const useChat = (
...(responseItem.workflowProcess!.tracing[currentIndex]?.extras
? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras }
: {}),
...(responseItem.workflowProcess!.tracing[currentIndex]?.retryDetail
? { retryDetail: responseItem.workflowProcess!.tracing[currentIndex].retryDetail }
: {}),
...data,
} as any
handleUpdateChatList(produce(chatListRef.current, (draft) => {
@@ -25,6 +25,7 @@ import {
import { SimpleBtn } from '../../app/text-generate/item'
import Toast from '../../base/toast'
import IterationResultPanel from '../run/iteration-result-panel'
import RetryResultPanel from '../run/retry-result-panel'
import InputsPanel from './inputs-panel'
import cn from '@/utils/classnames'
import Loading from '@/app/components/base/loading'
@@ -53,11 +54,16 @@ const WorkflowPreview = () => {
}, [workflowRunningData])

const [iterationRunResult, setIterationRunResult] = useState<NodeTracing[][]>([])
const [retryRunResult, setRetryRunResult] = useState<NodeTracing[]>([])
const [iterDurationMap, setIterDurationMap] = useState<IterationDurationMap>({})
const [isShowIterationDetail, {
setTrue: doShowIterationDetail,
setFalse: doHideIterationDetail,
}] = useBoolean(false)
const [isShowRetryDetail, {
setTrue: doShowRetryDetail,
setFalse: doHideRetryDetail,
}] = useBoolean(false)

const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterationDurationMap: IterationDurationMap) => {
setIterDurationMap(iterationDurationMap)
@@ -65,6 +71,11 @@ const WorkflowPreview = () => {
doShowIterationDetail()
}, [doShowIterationDetail])

const handleRetryDetail = useCallback((detail: NodeTracing[]) => {
setRetryRunResult(detail)
doShowRetryDetail()
}, [doShowRetryDetail])

if (isShowIterationDetail) {
return (
<div className={`
@@ -201,11 +212,12 @@ const WorkflowPreview = () => {
<Loading />
</div>
)}
{currentTab === 'TRACING' && (
{currentTab === 'TRACING' && !isShowRetryDetail && (
<TracingPanel
className='bg-background-section-burn'
list={workflowRunningData?.tracing || []}
onShowIterationDetail={handleShowIterationDetail}
onShowRetryDetail={handleRetryDetail}
/>
)}
{currentTab === 'TRACING' && !workflowRunningData?.tracing?.length && (
@@ -213,7 +225,14 @@ const WorkflowPreview = () => {
<Loading />
</div>
)}

{
currentTab === 'TRACING' && isShowRetryDetail && (
<RetryResultPanel
list={retryRunResult}
onBack={doHideRetryDetail}
/>
)
}
</div>
</>
)}
@@ -9,6 +9,7 @@ import OutputPanel from './output-panel'
import ResultPanel from './result-panel'
import TracingPanel from './tracing-panel'
import IterationResultPanel from './iteration-result-panel'
import RetryResultPanel from './retry-result-panel'
import cn from '@/utils/classnames'
import { ToastContext } from '@/app/components/base/toast'
import Loading from '@/app/components/base/loading'
@@ -107,6 +108,18 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
const processNonIterationNode = (item: NodeTracing) => {
const { execution_metadata } = item
if (!execution_metadata?.iteration_id) {
if (item.status === 'retry') {
const retryNode = result.find(node => node.node_id === item.node_id)

if (retryNode) {
if (retryNode?.retryDetail)
retryNode.retryDetail.push(item)
else
retryNode.retryDetail = [item]
}

return
}
result.push(item)
return
}
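processNonIterationNode above folds 'retry' status events into the retryDetail list of the node they belong to instead of appending them as separate tracing rows. The same grouping rule as a pure helper, for illustration only (the helper name is hypothetical):

// Illustrative helper; mirrors the retry-grouping branch of processNonIterationNode.
import type { NodeTracing } from '@/types/workflow'

const attachRetryEvent = (result: NodeTracing[], item: NodeTracing): void => {
  const retryNode = result.find(node => node.node_id === item.node_id)
  if (!retryNode)
    return
  if (retryNode.retryDetail)
    retryNode.retryDetail.push(item)
  else
    retryNode.retryDetail = [item]
}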
@@ -181,10 +194,15 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe

const [iterationRunResult, setIterationRunResult] = useState<NodeTracing[][]>([])
const [iterDurationMap, setIterDurationMap] = useState<IterationDurationMap>({})
const [retryRunResult, setRetryRunResult] = useState<NodeTracing[]>([])
const [isShowIterationDetail, {
setTrue: doShowIterationDetail,
setFalse: doHideIterationDetail,
}] = useBoolean(false)
const [isShowRetryDetail, {
setTrue: doShowRetryDetail,
setFalse: doHideRetryDetail,
}] = useBoolean(false)

const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => {
setIterationRunResult(detail)
@@ -192,6 +210,11 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
setIterDurationMap(iterDurationMap)
}, [doShowIterationDetail, setIterationRunResult, setIterDurationMap])

const handleShowRetryDetail = useCallback((detail: NodeTracing[]) => {
setRetryRunResult(detail)
doShowRetryDetail()
}, [doShowRetryDetail, setRetryRunResult])

if (isShowIterationDetail) {
return (
<div className='grow relative flex flex-col'>
@@ -261,13 +284,22 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
exceptionCounts={runDetail.exceptions_count}
/>
)}
{!loading && currentTab === 'TRACING' && (
{!loading && currentTab === 'TRACING' && !isShowRetryDetail && (
<TracingPanel
className='bg-background-section-burn'
list={list}
onShowIterationDetail={handleShowIterationDetail}
onShowRetryDetail={handleShowRetryDetail}
/>
)}
{
!loading && currentTab === 'TRACING' && isShowRetryDetail && (
<RetryResultPanel
list={retryRunResult}
onBack={doHideRetryDetail}
/>
)
}
</div>
</div>
)
@@ -8,6 +8,7 @@ import {
RiCheckboxCircleFill,
RiErrorWarningLine,
RiLoader2Line,
RiRestartFill,
} from '@remixicon/react'
import BlockIcon from '../block-icon'
import { BlockEnum } from '../types'
@@ -20,6 +21,7 @@ import Button from '@/app/components/base/button'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import type { IterationDurationMap, NodeTracing } from '@/types/workflow'
import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip'
import { hasRetryNode } from '@/app/components/workflow/utils'

type Props = {
className?: string
@@ -28,8 +30,10 @@ type Props = {
hideInfo?: boolean
hideProcessDetail?: boolean
onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void
onShowRetryDetail?: (detail: NodeTracing[]) => void
notShowIterationNav?: boolean
justShowIterationNavArrow?: boolean
justShowRetryNavArrow?: boolean
}

const NodePanel: FC<Props> = ({
@@ -39,6 +43,7 @@ const NodePanel: FC<Props> = ({
hideInfo = false,
hideProcessDetail,
onShowIterationDetail,
onShowRetryDetail,
notShowIterationNav,
justShowIterationNavArrow,
}) => {
@@ -88,11 +93,17 @@ const NodePanel: FC<Props> = ({
}, [nodeInfo.expand, setCollapseState])

const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration
const isRetryNode = hasRetryNode(nodeInfo.node_type) && nodeInfo.retryDetail
const handleOnShowIterationDetail = (e: React.MouseEvent<HTMLButtonElement>) => {
e.stopPropagation()
e.nativeEvent.stopImmediatePropagation()
onShowIterationDetail?.(nodeInfo.details || [], nodeInfo?.iterDurationMap || nodeInfo.execution_metadata?.iteration_duration_map || {})
}
const handleOnShowRetryDetail = (e: React.MouseEvent<HTMLButtonElement>) => {
e.stopPropagation()
e.nativeEvent.stopImmediatePropagation()
onShowRetryDetail?.(nodeInfo.retryDetail || [])
}
return (
<div className={cn('px-2 py-1', className)}>
<div className='group transition-all bg-background-default border border-components-panel-border rounded-[10px] shadow-xs hover:shadow-md'>
@@ -169,6 +180,19 @@ const NodePanel: FC<Props> = ({
<Split className='mt-2' />
</div>
)}
{isRetryNode && (
<Button
className='flex items-center justify-between mb-1 w-full'
variant='tertiary'
onClick={handleOnShowRetryDetail}
>
<div className='flex items-center'>
<RiRestartFill className='mr-0.5 w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
{t('workflow.nodes.common.retry.retries', { num: nodeInfo.retryDetail?.length })}
</div>
<RiArrowRightSLine className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
</Button>
)}
<div className={cn('mb-1', hideInfo && '!px-2 !py-0.5')}>
{(nodeInfo.status === 'stopped') && (
<StatusContainer status='stopped'>
@@ -1,11 +1,17 @@
'use client'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiArrowRightSLine,
RiRestartFill,
} from '@remixicon/react'
import StatusPanel from './status'
import MetaData from './meta'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip'
import type { NodeTracing } from '@/types/workflow'
import Button from '@/app/components/base/button'

type ResultPanelProps = {
inputs?: string
@@ -22,6 +28,8 @@ type ResultPanelProps = {
showSteps?: boolean
exceptionCounts?: number
execution_metadata?: any
retry_events?: NodeTracing[]
onShowRetryDetail?: (retries: NodeTracing[]) => void
}

const ResultPanel: FC<ResultPanelProps> = ({
@@ -38,8 +46,11 @@ const ResultPanel: FC<ResultPanelProps> = ({
showSteps,
exceptionCounts,
execution_metadata,
retry_events,
onShowRetryDetail,
}) => {
const { t } = useTranslation()

return (
<div className='bg-components-panel-bg py-2'>
<div className='px-4 py-2'>
@@ -51,6 +62,23 @@ const ResultPanel: FC<ResultPanelProps> = ({
exceptionCounts={exceptionCounts}
/>
</div>
{
retry_events?.length && onShowRetryDetail && (
<div className='px-4'>
<Button
className='flex items-center justify-between w-full'
variant='tertiary'
onClick={() => onShowRetryDetail(retry_events)}
>
<div className='flex items-center'>
<RiRestartFill className='mr-0.5 w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
{t('workflow.nodes.common.retry.retries', { num: retry_events?.length })}
</div>
<RiArrowRightSLine className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
</Button>
</div>
)
}
<div className='px-4 py-2 flex flex-col gap-2'>
<CodeEditor
readOnly
@@ -0,0 +1,46 @@
'use client'

import type { FC } from 'react'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiArrowLeftLine,
} from '@remixicon/react'
import TracingPanel from './tracing-panel'
import type { NodeTracing } from '@/types/workflow'

type Props = {
list: NodeTracing[]
onBack: () => void
}

const RetryResultPanel: FC<Props> = ({
list,
onBack,
}) => {
const { t } = useTranslation()

return (
<div>
<div
className='flex items-center px-4 h-8 text-text-accent-secondary bg-components-panel-bg system-sm-medium cursor-pointer'
onClick={(e) => {
e.stopPropagation()
e.nativeEvent.stopImmediatePropagation()
onBack()
}}
>
<RiArrowLeftLine className='mr-1 w-4 h-4' />
{t('workflow.singleRun.back')}
</div>
<TracingPanel
list={list.map((item, index) => ({
...item,
title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`,
}))}
className='bg-background-section-burn'
/>
</div >
)
}
export default memo(RetryResultPanel)