Merge branch 'main' into fix/remove-the-retry-index-field

This commit is contained in:
Novice Lee 2024-12-23 10:47:43 +08:00
commit 4d35df9210
125 changed files with 2847 additions and 370 deletions

View File

@ -555,7 +555,8 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
if language not in languages:
language = "en-US"
name = name.strip()
# Validates name encoding for non-Latin characters.
name = name.strip().encode("utf-8").decode("utf-8") if name else None
# generate random password
new_password = secrets.token_urlsafe(16)

View File

@ -4,3 +4,8 @@ from werkzeug.exceptions import HTTPException
class FilenameNotExistsError(HTTPException):
code = 400
description = "The specified filename does not exist."
class RemoteFileUploadError(HTTPException):
code = 400
description = "Error uploading remote file."

View File

@ -1,12 +1,14 @@
from flask_login import current_user
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.console import api
from controllers.console.explore.error import NotChatAppError
from controllers.console.explore.wraps import InstalledAppResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@ -34,14 +36,16 @@ class ConversationListApi(InstalledAppResource):
pinned = True if args["pinned"] == "true" else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=current_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.EXPLORE,
pinned=pinned,
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@ -70,7 +70,7 @@ class MessageFeedbackApi(InstalledAppResource):
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, current_user, args["rating"])
MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -7,6 +7,7 @@ from flask_restful import Resource, marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
@ -43,10 +44,14 @@ class RemoteFileUploadApi(Resource):
url = args["url"]
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
resp.raise_for_status()
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)

View File

@ -3,12 +3,14 @@ import io
from flask import send_file
from flask_login import current_user
from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from libs.helper import alphanumeric, uuid_value
from libs.login import login_required
from services.tools.api_tools_manage_service import ApiToolManageService
@ -91,12 +93,16 @@ class ToolBuiltinProviderUpdateApi(Resource):
args = parser.parse_args()
return BuiltinToolManageService.update_builtin_tool_provider(
user_id,
tenant_id,
provider,
args["credentials"],
)
with Session(db.engine) as session:
result = BuiltinToolManageService.update_builtin_tool_provider(
session=session,
user_id=user_id,
tenant_id=tenant_id,
provider_name=provider,
credentials=args["credentials"],
)
session.commit()
return result
class ToolBuiltinProviderGetCredentialsApi(Resource):
@ -104,13 +110,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
@login_required
@account_initialization_required
def get(self, provider):
user_id = current_user.id
tenant_id = current_user.current_tenant_id
return BuiltinToolManageService.get_builtin_tool_provider_credentials(
user_id,
tenant_id,
provider,
tenant_id=tenant_id,
provider_name=provider,
)

View File

@ -1,5 +1,6 @@
from flask_restful import Resource, marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
import services
@ -7,6 +8,7 @@ from controllers.service_api import api
from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import (
conversation_delete_fields,
conversation_infinite_scroll_pagination_fields,
@ -39,14 +41,16 @@ class ConversationApi(Resource):
args = parser.parse_args()
try:
return ConversationService.pagination_by_last_id(
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
with Session(db.engine) as session:
return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.SERVICE_API,
sort_by=args["sort_by"],
)
except services.errors.conversation.LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@ -104,10 +104,11 @@ class MessageFeedbackApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json")
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -1,11 +1,13 @@
from flask_restful import marshal_with, reqparse
from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.web import api
from controllers.web.error import NotChatAppError
from controllers.web.wraps import WebApiResource
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
from libs.helper import uuid_value
from models.model import AppMode
@ -40,15 +42,17 @@ class ConversationListApi(WebApiResource):
pinned = True if args["pinned"] == "true" else False
try:
return WebConversationService.pagination_by_last_id(
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
)
with Session(db.engine) as session:
return WebConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=end_user,
last_id=args["last_id"],
limit=args["limit"],
invoke_from=InvokeFrom.WEB_APP,
pinned=pinned,
sort_by=args["sort_by"],
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")

View File

@ -108,7 +108,7 @@ class MessageFeedbackApi(WebApiResource):
args = parser.parse_args()
try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"])
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"])
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@ -5,6 +5,7 @@ from flask_restful import marshal_with, reqparse
import services
from controllers.common import helpers
from controllers.common.errors import RemoteFileUploadError
from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
@ -38,10 +39,14 @@ class RemoteFileUploadApi(WebApiResource):
url = args["url"]
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3)
resp.raise_for_status()
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)

View File

@ -4,14 +4,17 @@ import logging
import queue
import re
import threading
from collections.abc import Iterable
from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
WorkflowQueueMessage,
)
from core.model_manager import ModelManager
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -21,7 +24,7 @@ class AudioTrunk:
self.status = status
def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
if not text_content or text_content.isspace():
return
return model_instance.invoke_tts(
@ -29,13 +32,19 @@ def _invoice_tts(text_content: str, model_instance, tenant_id: str, voice: str):
)
def _process_future(future_queue, audio_queue):
def _process_future(
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
audio_queue: queue.Queue[AudioTrunk],
):
while True:
try:
future = future_queue.get()
if future is None:
break
for audio in future.result():
invoke_result = future.result()
if not invoke_result:
continue
for audio in invoke_result:
audio_base64 = base64.b64encode(bytes(audio))
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
except Exception as e:
@ -49,8 +58,8 @@ class AppGeneratorTTSPublisher:
self.logger = logging.getLogger(__name__)
self.tenant_id = tenant_id
self.msg_text = ""
self._audio_queue = queue.Queue()
self._msg_queue = queue.Queue()
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self.match = re.compile(r"[。.!?]")
self.model_manager = ModelManager()
self.model_instance = self.model_manager.get_default_model_instance(
@ -66,14 +75,11 @@ class AppGeneratorTTSPublisher:
self._runtime_thread = threading.Thread(target=self._runtime).start()
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
def publish(self, message):
try:
self._msg_queue.put(message)
except Exception as e:
self.logger.warning(e)
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
self._msg_queue.put(message)
def _runtime(self):
future_queue = queue.Queue()
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
while True:
try:
@ -110,7 +116,7 @@ class AppGeneratorTTSPublisher:
break
future_queue.put(None)
def check_and_get_audio(self) -> AudioTrunk | None:
def check_and_get_audio(self):
try:
if self._last_audio_event and self._last_audio_event.status == "finish":
if self.executor:

View File

@ -197,11 +197,11 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream_response=stream_response,
)
def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@ -222,7 +222,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -512,7 +512,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._message_to_stream_response(

View File

@ -1,7 +1,6 @@
import queue
import time
from abc import abstractmethod
from collections.abc import Generator
from enum import Enum
from typing import Any
@ -11,9 +10,11 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
MessageQueueMessage,
QueueErrorEvent,
QueuePingEvent,
QueueStopEvent,
WorkflowQueueMessage,
)
from extensions.ext_redis import redis_client
@ -37,11 +38,11 @@ class AppQueueManager:
AppQueueManager._generate_task_belong_cache_key(self._task_id), 1800, f"{user_prefix}-{self._user_id}"
)
q = queue.Queue()
q: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
self._q = q
def listen(self) -> Generator:
def listen(self):
"""
Listen to queue
:return:

View File

@ -171,11 +171,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
yield WorkflowAppStreamResponse(workflow_run_id=workflow_run_id, stream_response=stream_response)
def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if not publisher:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None
@ -196,7 +196,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
for response in self._process_stream_response(tts_publisher=tts_publisher, trace_manager=trace_manager):
while True:
audio_response = self._listen_audio_msg(tts_publisher, task_id=task_id)
audio_response = self._listen_audio_msg(publisher=tts_publisher, task_id=task_id)
if audio_response:
yield audio_response
else:
@ -421,7 +421,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
# only publish tts message at text chunk streaming
if tts_publisher:
tts_publisher.publish(message=queue_message)
tts_publisher.publish(queue_message)
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(

View File

@ -201,11 +201,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
stream_response=stream_response,
)
def _listen_audio_msg(self, publisher, task_id: str):
def _listen_audio_msg(self, publisher: AppGeneratorTTSPublisher | None, task_id: str):
if publisher is None:
return None
audio_msg: AudioTrunk = publisher.check_and_get_audio()
if audio_msg and audio_msg.status != "finish":
audio_msg = publisher.check_and_get_audio()
if audio_msg and isinstance(audio_msg, AudioTrunk) and audio_msg.status != "finish":
# audio_str = audio_msg.audio.decode('utf-8', errors='ignore')
return MessageAudioStreamResponse(audio=audio_msg.audio, task_id=task_id)
return None

View File

@ -508,6 +508,12 @@ class WorkflowCycleManage:
:param workflow_run: workflow run
:return:
"""
# Attach WorkflowRun to an active session so "created_by_role" can be accessed.
workflow_run = db.session.merge(workflow_run)
# Refresh to ensure any expired attributes are fully loaded
db.session.refresh(workflow_run)
created_by = None
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
created_by_account = workflow_run.created_by_account

View File

@ -1,7 +1,7 @@
from typing import Optional
class LLMError(Exception):
class LLMError(ValueError):
"""Base class for all LLM exceptions."""
description: Optional[str] = None
@ -16,7 +16,7 @@ class LLMBadRequestError(LLMError):
description = "Bad Request"
class ProviderTokenNotInitError(Exception):
class ProviderTokenNotInitError(ValueError):
"""
Custom exception raised when the provider token is not initialized.
"""
@ -27,7 +27,7 @@ class ProviderTokenNotInitError(Exception):
self.description = args[0] if args else self.description
class QuotaExceededError(Exception):
class QuotaExceededError(ValueError):
"""
Custom exception raised when the quota for a provider has been exceeded.
"""
@ -35,7 +35,7 @@ class QuotaExceededError(Exception):
description = "Quota Exceeded"
class AppInvokeQuotaExceededError(Exception):
class AppInvokeQuotaExceededError(ValueError):
"""
Custom exception raised when the quota for an app has been exceeded.
"""
@ -43,7 +43,7 @@ class AppInvokeQuotaExceededError(Exception):
description = "App Invoke Quota Exceeded"
class ModelCurrentlyNotSupportError(Exception):
class ModelCurrentlyNotSupportError(ValueError):
"""
Custom exception raised when the model not support
"""
@ -51,7 +51,7 @@ class ModelCurrentlyNotSupportError(Exception):
description = "Model Currently Not Support"
class InvokeRateLimitError(Exception):
class InvokeRateLimitError(ValueError):
"""Raised when the Invoke returns rate limit error."""
description = "Rate Limit Error"

View File

@ -118,7 +118,7 @@ class CodeExecutor:
return response.data.stdout or ""
@classmethod
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]) -> dict:
def execute_workflow_code_template(cls, language: CodeLanguage, code: str, inputs: Mapping[str, Any]):
"""
Execute code
:param language: code language

View File

@ -25,7 +25,7 @@ class TemplateTransformer(ABC):
return runner_script, preload_script
@classmethod
def extract_result_str_from_response(cls, response: str) -> str:
def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
raise ValueError("Failed to parse result")
@ -33,13 +33,21 @@ class TemplateTransformer(ABC):
return result
@classmethod
def transform_response(cls, response: str) -> dict:
def transform_response(cls, response: str) -> Mapping[str, Any]:
"""
Transform response to dict
:param response: response
:return:
"""
return json.loads(cls.extract_result_str_from_response(response))
try:
result = json.loads(cls.extract_result_str_from_response(response))
except json.JSONDecodeError:
raise ValueError("failed to parse response")
if not isinstance(result, dict):
raise ValueError("result must be a dict")
if not all(isinstance(k, str) for k in result):
raise ValueError("result keys must be strings")
return result
@classmethod
@abstractmethod

View File

@ -24,7 +24,7 @@ BACKOFF_FACTOR = 0.5
STATUS_FORCELIST = [429, 500, 502, 503, 504]
class MaxRetriesExceededError(Exception):
class MaxRetriesExceededError(ValueError):
"""Raised when the maximum number of retries is exceeded."""
pass

View File

@ -1,2 +1,2 @@
class OutputParserError(Exception):
class OutputParserError(ValueError):
pass

View File

@ -1,7 +1,7 @@
from typing import Optional
class InvokeError(Exception):
class InvokeError(ValueError):
"""Base class for all LLM exceptions."""
description: Optional[str] = None

View File

@ -1,4 +1,4 @@
class CredentialsValidateFailedError(Exception):
class CredentialsValidateFailedError(ValueError):
"""
Credentials validate failed error
"""

View File

@ -531,7 +531,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
"source": {
"type": "base64",
"media_type": message_content.mime_type,
"data": message_content.data,
"data": message_content.base64_data,
},
}
sub_messages.append(sub_message_dict)

View File

@ -21,6 +21,7 @@ from core.model_runtime.entities.message_entities import (
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
@ -143,7 +144,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
"""
try:
ping_message = SystemPromptMessage(content="ping")
ping_message = UserPromptMessage(content="ping")
self._generate(model, credentials, [ping_message], {"max_output_tokens": 5})
except Exception as ex:
@ -187,17 +188,23 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
config_kwargs["stop_sequences"] = stop
genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)
history = []
system_instruction = None
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
elif content["role"] == "system":
system_instruction = content["parts"][0]
else:
history.append(content)
if not history:
raise InvokeError("The user prompt message is required. You only add a system prompt message.")
google_model = genai.GenerativeModel(model_name=model, system_instruction=system_instruction)
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(**config_kwargs),
@ -404,7 +411,10 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
)
return glm_content
elif isinstance(message, SystemPromptMessage):
return {"role": "user", "parts": [to_part(message.content)]}
if isinstance(message.content, list):
text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content)
message.content = "".join(c.data for c in text_contents)
return {"role": "system", "parts": [to_part(message.content)]}
elif isinstance(message, ToolPromptMessage):
return {
"role": "function",

View File

@ -355,7 +355,13 @@ class TraceTask:
def conversation_trace(self, **kwargs):
return kwargs
def workflow_trace(self, workflow_run: WorkflowRun, conversation_id, user_id):
def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id):
if not workflow_run:
raise ValueError("Workflow run not found")
db.session.merge(workflow_run)
db.sessoin.refresh(workflow_run)
workflow_id = workflow_run.workflow_id
tenant_id = workflow_run.tenant_id
workflow_run_id = workflow_run.id

View File

@ -83,11 +83,15 @@ class DataPostProcessor:
if reranking_model:
try:
model_manager = ModelManager()
reranking_provider_name = reranking_model.get("reranking_provider_name")
reranking_model_name = reranking_model.get("reranking_model_name")
if not reranking_provider_name or not reranking_model_name:
return None
rerank_model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
provider=reranking_model["reranking_provider_name"],
provider=reranking_provider_name,
model_type=ModelType.RERANK,
model=reranking_model["reranking_model_name"],
model=reranking_model_name,
)
return rerank_model_instance
except InvokeAuthorizationError:

View File

@ -103,7 +103,7 @@ class RetrievalService:
if exceptions:
exception_message = ";\n".join(exceptions)
raise Exception(exception_message)
raise ValueError(exception_message)
if retrieval_method == RetrievalMethod.HYBRID_SEARCH.value:
data_post_processor = DataPostProcessor(

View File

@ -62,7 +62,7 @@ class AppToolProviderEntity(ToolProviderController):
user_input_form_list = app_model_config.user_input_form_list
for input_form in user_input_form_list:
# get type
form_type = input_form.keys()[0]
form_type = list(input_form.keys())[0]
default = input_form[form_type]["default"]
required = input_form[form_type]["required"]
label = input_form[form_type]["label"]

View File

@ -0,0 +1,115 @@
import json
import operator
from typing import Any, Optional, Union
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class BedrockRetrieveTool(BuiltinTool):
bedrock_client: Any = None
knowledge_base_id: str = None
topk: int = None
def _bedrock_retrieve(
self, query_input: str, knowledge_base_id: str, num_results: int, metadata_filter: Optional[dict] = None
):
try:
retrieval_query = {"text": query_input}
retrieval_configuration = {"vectorSearchConfiguration": {"numberOfResults": num_results}}
# 如果有元数据过滤条件,则添加到检索配置中
if metadata_filter:
retrieval_configuration["vectorSearchConfiguration"]["filter"] = metadata_filter
response = self.bedrock_client.retrieve(
knowledgeBaseId=knowledge_base_id,
retrievalQuery=retrieval_query,
retrievalConfiguration=retrieval_configuration,
)
results = []
for result in response.get("retrievalResults", []):
results.append(
{
"content": result.get("content", {}).get("text", ""),
"score": result.get("score", 0.0),
"metadata": result.get("metadata", {}),
}
)
return results
except Exception as e:
raise Exception(f"Error retrieving from knowledge base: {str(e)}")
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
line = 0
try:
if not self.bedrock_client:
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.bedrock_client = boto3.client("bedrock-agent-runtime", region_name=aws_region)
else:
self.bedrock_client = boto3.client("bedrock-agent-runtime")
line = 1
if not self.knowledge_base_id:
self.knowledge_base_id = tool_parameters.get("knowledge_base_id")
if not self.knowledge_base_id:
return self.create_text_message("Please provide knowledge_base_id")
line = 2
if not self.topk:
self.topk = tool_parameters.get("topk", 5)
line = 3
query = tool_parameters.get("query", "")
if not query:
return self.create_text_message("Please input query")
# 获取元数据过滤条件(如果存在)
metadata_filter_str = tool_parameters.get("metadata_filter")
metadata_filter = json.loads(metadata_filter_str) if metadata_filter_str else None
line = 4
retrieved_docs = self._bedrock_retrieve(
query_input=query,
knowledge_base_id=self.knowledge_base_id,
num_results=self.topk,
metadata_filter=metadata_filter, # 将元数据过滤条件传递给检索方法
)
line = 5
# Sort results by score in descending order
sorted_docs = sorted(retrieved_docs, key=operator.itemgetter("score"), reverse=True)
line = 6
return [self.create_json_message(res) for res in sorted_docs]
except Exception as e:
return self.create_text_message(f"Exception {str(e)}, line : {line}")
def validate_parameters(self, parameters: dict[str, Any]) -> None:
"""
Validate the parameters
"""
if not parameters.get("knowledge_base_id"):
raise ValueError("knowledge_base_id is required")
if not parameters.get("query"):
raise ValueError("query is required")
# 可选:可以验证元数据过滤条件是否为有效的 JSON 字符串(如果提供)
metadata_filter_str = parameters.get("metadata_filter")
if metadata_filter_str and not isinstance(json.loads(metadata_filter_str), dict):
raise ValueError("metadata_filter must be a valid JSON object")

View File

@ -0,0 +1,87 @@
identity:
name: bedrock_retrieve
author: AWS
label:
en_US: Bedrock Retrieve
zh_Hans: Bedrock检索
pt_BR: Bedrock Retrieve
icon: icon.svg
description:
human:
en_US: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
zh_Hans: Amazon Bedrock知识库检索工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署说明
pt_BR: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base.
llm: A tool for retrieving relevant information from Amazon Bedrock Knowledge Base. You can find deploy instructions on Github Repo - https://github.com/aws-samples/dify-aws-tool
parameters:
- name: knowledge_base_id
type: string
required: true
label:
en_US: Bedrock Knowledge Base ID
zh_Hans: Bedrock知识库ID
pt_BR: Bedrock Knowledge Base ID
human_description:
en_US: ID of the Bedrock Knowledge Base to retrieve from
zh_Hans: 用于检索的Bedrock知识库ID
pt_BR: ID of the Bedrock Knowledge Base to retrieve from
llm_description: ID of the Bedrock Knowledge Base to retrieve from
form: form
- name: query
type: string
required: true
label:
en_US: Query string
zh_Hans: 查询语句
pt_BR: Query string
human_description:
en_US: The search query to retrieve relevant information
zh_Hans: 用于检索相关信息的查询语句
pt_BR: The search query to retrieve relevant information
llm_description: The search query to retrieve relevant information
form: llm
- name: topk
type: number
required: false
form: form
label:
en_US: Limit for results count
zh_Hans: 返回结果数量限制
pt_BR: Limit for results count
human_description:
en_US: Maximum number of results to return
zh_Hans: 最大返回结果数量
pt_BR: Maximum number of results to return
min: 1
max: 10
default: 5
- name: aws_region
type: string
required: false
label:
en_US: AWS Region
zh_Hans: AWS 区域
pt_BR: AWS Region
human_description:
en_US: AWS region where the Bedrock Knowledge Base is located
zh_Hans: Bedrock知识库所在的AWS区域
pt_BR: AWS region where the Bedrock Knowledge Base is located
llm_description: AWS region where the Bedrock Knowledge Base is located
form: form
- name: metadata_filter
type: string
required: false
label:
en_US: Metadata Filter
zh_Hans: 元数据过滤器
pt_BR: Metadata Filter
human_description:
en_US: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
zh_Hans: '元数据的JSON格式过滤条件例如{{"greaterThan": {"key: "aaa", "value": 10}}'
pt_BR: 'JSON formatted filter conditions for metadata (e.g., {"greaterThan": {"key: "aaa", "value": 10}})'
form: form

View File

@ -0,0 +1,357 @@
import base64
import json
import logging
import re
from datetime import datetime
from typing import Any, Union
from urllib.parse import urlparse
import boto3
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class NovaCanvasTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke AWS Bedrock Nova Canvas model for image generation
"""
# Get common parameters
prompt = tool_parameters.get("prompt", "")
image_output_s3uri = tool_parameters.get("image_output_s3uri", "").strip()
if not prompt:
return self.create_text_message("Please provide a text prompt for image generation.")
if not image_output_s3uri or urlparse(image_output_s3uri).scheme != "s3":
return self.create_text_message("Please provide an valid S3 URI for image output.")
task_type = tool_parameters.get("task_type", "TEXT_IMAGE")
aws_region = tool_parameters.get("aws_region", "us-east-1")
# Get common image generation config parameters
width = tool_parameters.get("width", 1024)
height = tool_parameters.get("height", 1024)
cfg_scale = tool_parameters.get("cfg_scale", 8.0)
negative_prompt = tool_parameters.get("negative_prompt", "")
seed = tool_parameters.get("seed", 0)
quality = tool_parameters.get("quality", "standard")
# Handle S3 image if provided
image_input_s3uri = tool_parameters.get("image_input_s3uri", "")
if task_type != "TEXT_IMAGE":
if not image_input_s3uri or urlparse(image_input_s3uri).scheme != "s3":
return self.create_text_message("Please provide a valid S3 URI for image to image generation.")
# Parse S3 URI
parsed_uri = urlparse(image_input_s3uri)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip("/")
# Initialize S3 client and download image
s3_client = boto3.client("s3")
response = s3_client.get_object(Bucket=bucket, Key=key)
image_data = response["Body"].read()
# Base64 encode the image
input_image = base64.b64encode(image_data).decode("utf-8")
try:
# Initialize Bedrock client
bedrock = boto3.client(service_name="bedrock-runtime", region_name=aws_region)
# Base image generation config
image_generation_config = {
"width": width,
"height": height,
"cfgScale": cfg_scale,
"seed": seed,
"numberOfImages": 1,
"quality": quality,
}
# Prepare request body based on task type
body = {"imageGenerationConfig": image_generation_config}
if task_type == "TEXT_IMAGE":
body["taskType"] = "TEXT_IMAGE"
body["textToImageParams"] = {"text": prompt}
if negative_prompt:
body["textToImageParams"]["negativeText"] = negative_prompt
elif task_type == "COLOR_GUIDED_GENERATION":
colors = tool_parameters.get("colors", "#ff8080-#ffb280-#ffe680-#ffe680")
if not self._validate_color_string(colors):
return self.create_text_message("Please provide valid colors in hexadecimal format.")
body["taskType"] = "COLOR_GUIDED_GENERATION"
body["colorGuidedGenerationParams"] = {
"colors": colors.split("-"),
"referenceImage": input_image,
"text": prompt,
}
if negative_prompt:
body["colorGuidedGenerationParams"]["negativeText"] = negative_prompt
elif task_type == "IMAGE_VARIATION":
similarity_strength = tool_parameters.get("similarity_strength", 0.5)
body["taskType"] = "IMAGE_VARIATION"
body["imageVariationParams"] = {
"images": [input_image],
"similarityStrength": similarity_strength,
"text": prompt,
}
if negative_prompt:
body["imageVariationParams"]["negativeText"] = negative_prompt
elif task_type == "INPAINTING":
mask_prompt = tool_parameters.get("mask_prompt")
if not mask_prompt:
return self.create_text_message("Please provide a mask prompt for image inpainting.")
body["taskType"] = "INPAINTING"
body["inPaintingParams"] = {"image": input_image, "maskPrompt": mask_prompt, "text": prompt}
if negative_prompt:
body["inPaintingParams"]["negativeText"] = negative_prompt
elif task_type == "OUTPAINTING":
mask_prompt = tool_parameters.get("mask_prompt")
if not mask_prompt:
return self.create_text_message("Please provide a mask prompt for image outpainting.")
outpainting_mode = tool_parameters.get("outpainting_mode", "DEFAULT")
body["taskType"] = "OUTPAINTING"
body["outPaintingParams"] = {
"image": input_image,
"maskPrompt": mask_prompt,
"outPaintingMode": outpainting_mode,
"text": prompt,
}
if negative_prompt:
body["outPaintingParams"]["negativeText"] = negative_prompt
elif task_type == "BACKGROUND_REMOVAL":
body["taskType"] = "BACKGROUND_REMOVAL"
body["backgroundRemovalParams"] = {"image": input_image}
else:
return self.create_text_message(f"Unsupported task type: {task_type}")
# Call Nova Canvas model
response = bedrock.invoke_model(
body=json.dumps(body),
modelId="amazon.nova-canvas-v1:0",
accept="application/json",
contentType="application/json",
)
# Process response
response_body = json.loads(response.get("body").read())
if response_body.get("error"):
raise Exception(f"Error in model response: {response_body.get('error')}")
base64_image = response_body.get("images")[0]
# Upload to S3 if image_output_s3uri is provided
try:
# Parse S3 URI for output
parsed_uri = urlparse(image_output_s3uri)
output_bucket = parsed_uri.netloc
output_base_path = parsed_uri.path.lstrip("/")
# Generate filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_key = f"{output_base_path}/canvas-output-{timestamp}.png"
# Initialize S3 client if not already done
s3_client = boto3.client("s3", region_name=aws_region)
# Decode base64 image and upload to S3
image_data = base64.b64decode(base64_image)
s3_client.put_object(Bucket=output_bucket, Key=output_key, Body=image_data, ContentType="image/png")
logger.info(f"Image uploaded to s3://{output_bucket}/{output_key}")
except Exception as e:
logger.exception("Failed to upload image to S3")
# Return image
return [
self.create_text_message(f"Image is available at: s3://{output_bucket}/{output_key}"),
self.create_blob_message(
blob=base64.b64decode(base64_image),
meta={"mime_type": "image/png"},
save_as=self.VariableKey.IMAGE.value,
),
]
except Exception as e:
return self.create_text_message(f"Failed to generate image: {str(e)}")
def _validate_color_string(self, color_string) -> bool:
color_pattern = r"^#[0-9a-fA-F]{6}(?:-#[0-9a-fA-F]{6})*$"
if re.match(color_pattern, color_string):
return True
return False
def get_runtime_parameters(self) -> list[ToolParameter]:
parameters = [
ToolParameter(
name="prompt",
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Text description of the image you want to generate or modify",
zh_Hans="您想要生成或修改的图像的文本描述",
),
llm_description="Describe the image you want to generate or how you want to modify the input image",
),
ToolParameter(
name="image_input_s3uri",
label=I18nObject(en_US="Input image s3 uri", zh_Hans="输入图片的s3 uri"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(en_US="Image to be modified", zh_Hans="想要修改的图片"),
),
ToolParameter(
name="image_output_s3uri",
label=I18nObject(en_US="Output Image S3 URI", zh_Hans="输出图片的S3 URI目录"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="S3 URI where the generated image should be uploaded", zh_Hans="生成的图像应该上传到的S3 URI"
),
),
ToolParameter(
name="width",
label=I18nObject(en_US="Width", zh_Hans="宽度"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=1024,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Width of the generated image", zh_Hans="生成图像的宽度"),
),
ToolParameter(
name="height",
label=I18nObject(en_US="Height", zh_Hans="高度"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=1024,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Height of the generated image", zh_Hans="生成图像的高度"),
),
ToolParameter(
name="cfg_scale",
label=I18nObject(en_US="CFG Scale", zh_Hans="CFG比例"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=8.0,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="How strongly the image should conform to the prompt", zh_Hans="图像应该多大程度上符合提示词"
),
),
ToolParameter(
name="negative_prompt",
label=I18nObject(en_US="Negative Prompt", zh_Hans="负面提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="",
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Things you don't want in the generated image", zh_Hans="您不想在生成的图像中出现的内容"
),
),
ToolParameter(
name="seed",
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=0,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Random seed for image generation", zh_Hans="图像生成的随机种子"),
),
ToolParameter(
name="aws_region",
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="us-east-1",
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
),
ToolParameter(
name="task_type",
label=I18nObject(en_US="Task Type", zh_Hans="任务类型"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="TEXT_IMAGE",
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(en_US="Type of image generation task", zh_Hans="图像生成任务的类型"),
),
ToolParameter(
name="quality",
label=I18nObject(en_US="Quality", zh_Hans="质量"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="standard",
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="Quality of the generated image (standard or premium)", zh_Hans="生成图像的质量(标准或高级)"
),
),
ToolParameter(
name="colors",
label=I18nObject(en_US="Colors", zh_Hans="颜色"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="List of colors for color-guided generation, example: #ff8080-#ffb280-#ffe680-#ffe680",
zh_Hans="颜色引导生成的颜色列表, 例子: #ff8080-#ffb280-#ffe680-#ffe680",
),
),
ToolParameter(
name="similarity_strength",
label=I18nObject(en_US="Similarity Strength", zh_Hans="相似度强度"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=0.5,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="How similar the generated image should be to the input image (0.0 to 1.0)",
zh_Hans="生成的图像应该与输入图像的相似程度0.0到1.0",
),
),
ToolParameter(
name="mask_prompt",
label=I18nObject(en_US="Mask Prompt", zh_Hans="蒙版提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Text description to generate mask for inpainting/outpainting",
zh_Hans="用于生成内补绘制/外补绘制蒙版的文本描述",
),
),
ToolParameter(
name="outpainting_mode",
label=I18nObject(en_US="Outpainting Mode", zh_Hans="外补绘制模式"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default="DEFAULT",
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="Mode for outpainting (DEFAULT or other supported modes)",
zh_Hans="外补绘制的模式DEFAULT或其他支持的模式",
),
),
]
return parameters

View File

@ -0,0 +1,175 @@
identity:
name: nova_canvas
author: AWS
label:
en_US: AWS Bedrock Nova Canvas
zh_Hans: AWS Bedrock Nova Canvas
icon: icon.svg
description:
human:
en_US: A tool for generating and modifying images using AWS Bedrock's Nova Canvas model. Supports text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html
zh_Hans: 使用 AWS Bedrock 的 Nova Canvas 模型生成和修改图像的工具。支持文生图、颜色引导生成、图像变体、内补绘制、外补绘制和背景移除功能, 输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/image-gen-req-resp-structure.html。
llm: Generate or modify images using AWS Bedrock's Nova Canvas model with multiple task types including text-to-image, color-guided generation, image variation, inpainting, outpainting, and background removal.
parameters:
- name: task_type
type: string
required: false
default: TEXT_IMAGE
label:
en_US: Task Type
zh_Hans: 任务类型
human_description:
en_US: Type of image generation task (TEXT_IMAGE, COLOR_GUIDED_GENERATION, IMAGE_VARIATION, INPAINTING, OUTPAINTING, BACKGROUND_REMOVAL)
zh_Hans: 图像生成任务的类型(文生图、颜色引导生成、图像变体、内补绘制、外补绘制、背景移除)
form: llm
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Text description of the image you want to generate or modify
zh_Hans: 您想要生成或修改的图像的文本描述
llm_description: Describe the image you want to generate or how you want to modify the input image
form: llm
- name: image_input_s3uri
type: string
required: false
label:
en_US: Input image s3 uri
zh_Hans: 输入图片的s3 uri
human_description:
en_US: The input image to modify (required for all modes except TEXT_IMAGE)
zh_Hans: 要修改的输入图像(除文生图外的所有模式都需要)
llm_description: The input image you want to modify. Required for all modes except TEXT_IMAGE.
form: llm
- name: image_output_s3uri
type: string
required: true
label:
en_US: Output S3 URI
zh_Hans: 输出S3 URI
human_description:
en_US: The S3 URI where the generated image will be saved. If provided, the image will be uploaded with name format canvas-output-{timestamp}.png
zh_Hans: 生成的图像将保存到的S3 URI。如果提供图像将以canvas-output-{timestamp}.png的格式上传
llm_description: Optional S3 URI where the generated image will be uploaded. The image will be saved with a timestamp-based filename.
form: form
- name: negative_prompt
type: string
required: false
label:
en_US: Negative Prompt
zh_Hans: 负面提示词
human_description:
en_US: Things you don't want in the generated image
zh_Hans: 您不想在生成的图像中出现的内容
form: llm
- name: width
type: number
required: false
label:
en_US: Width
zh_Hans: 宽度
human_description:
en_US: Width of the generated image
zh_Hans: 生成图像的宽度
form: form
default: 1024
- name: height
type: number
required: false
label:
en_US: Height
zh_Hans: 高度
human_description:
en_US: Height of the generated image
zh_Hans: 生成图像的高度
form: form
default: 1024
- name: cfg_scale
type: number
required: false
label:
en_US: CFG Scale
zh_Hans: CFG比例
human_description:
en_US: How strongly the image should conform to the prompt
zh_Hans: 图像应该多大程度上符合提示词
form: form
default: 8.0
- name: seed
type: number
required: false
label:
en_US: Seed
zh_Hans: 种子值
human_description:
en_US: Random seed for image generation
zh_Hans: 图像生成的随机种子
form: form
default: 0
- name: aws_region
type: string
required: false
default: us-east-1
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: AWS region for Bedrock service
zh_Hans: Bedrock 服务的 AWS 区域
form: form
- name: quality
type: string
required: false
default: standard
label:
en_US: Quality
zh_Hans: 质量
human_description:
en_US: Quality of the generated image (standard or premium)
zh_Hans: 生成图像的质量(标准或高级)
form: form
- name: colors
type: string
required: false
label:
en_US: Colors
zh_Hans: 颜色
human_description:
en_US: List of colors for color-guided generation
zh_Hans: 颜色引导生成的颜色列表
form: form
- name: similarity_strength
type: number
required: false
default: 0.5
label:
en_US: Similarity Strength
zh_Hans: 相似度强度
human_description:
en_US: How similar the generated image should be to the input image (0.0 to 1.0)
zh_Hans: 生成的图像应该与输入图像的相似程度0.0到1.0
form: form
- name: mask_prompt
type: string
required: false
label:
en_US: Mask Prompt
zh_Hans: 蒙版提示词
human_description:
en_US: Text description to generate mask for inpainting/outpainting
zh_Hans: 用于生成内补绘制/外补绘制蒙版的文本描述
form: llm
- name: outpainting_mode
type: string
required: false
default: DEFAULT
label:
en_US: Outpainting Mode
zh_Hans: 外补绘制模式
human_description:
en_US: Mode for outpainting (DEFAULT or other supported modes)
zh_Hans: 外补绘制的模式DEFAULT或其他支持的模式
form: form

View File

@ -0,0 +1,371 @@
import base64
import logging
import time
from io import BytesIO
from typing import Any, Optional, Union
from urllib.parse import urlparse
import boto3
from botocore.exceptions import ClientError
from PIL import Image
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.builtin_tool import BuiltinTool
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NOVA_REEL_DEFAULT_REGION = "us-east-1"
NOVA_REEL_DEFAULT_DIMENSION = "1280x720"
NOVA_REEL_DEFAULT_FPS = 24
NOVA_REEL_DEFAULT_DURATION = 6
NOVA_REEL_MODEL_ID = "amazon.nova-reel-v1:0"
NOVA_REEL_STATUS_CHECK_INTERVAL = 5
# Image requirements
NOVA_REEL_REQUIRED_IMAGE_WIDTH = 1280
NOVA_REEL_REQUIRED_IMAGE_HEIGHT = 720
NOVA_REEL_REQUIRED_IMAGE_MODE = "RGB"
class NovaReelTool(BuiltinTool):
def _invoke(
self, user_id: str, tool_parameters: dict[str, Any]
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
Invoke AWS Bedrock Nova Reel model for video generation.
Args:
user_id: The ID of the user making the request
tool_parameters: Dictionary containing the tool parameters
Returns:
ToolInvokeMessage containing either the video content or status information
"""
try:
# Validate and extract parameters
params = self._validate_and_extract_parameters(tool_parameters)
if isinstance(params, ToolInvokeMessage):
return params
# Initialize AWS clients
bedrock, s3_client = self._initialize_aws_clients(params["aws_region"])
# Prepare model input
model_input = self._prepare_model_input(params, s3_client)
if isinstance(model_input, ToolInvokeMessage):
return model_input
# Start video generation
invocation = self._start_video_generation(bedrock, model_input, params["video_output_s3uri"])
invocation_arn = invocation["invocationArn"]
# Handle async/sync mode
return self._handle_generation_mode(bedrock, s3_client, invocation_arn, params["async_mode"])
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code", "Unknown")
error_message = e.response.get("Error", {}).get("Message", str(e))
logger.exception(f"AWS API error: {error_code} - {error_message}")
return self.create_text_message(f"AWS service error: {error_code} - {error_message}")
except Exception as e:
logger.error(f"Unexpected error in video generation: {str(e)}", exc_info=True)
return self.create_text_message(f"Failed to generate video: {str(e)}")
def _validate_and_extract_parameters(
self, tool_parameters: dict[str, Any]
) -> Union[dict[str, Any], ToolInvokeMessage]:
"""Validate and extract parameters from the input dictionary."""
prompt = tool_parameters.get("prompt", "")
video_output_s3uri = tool_parameters.get("video_output_s3uri", "").strip()
# Validate required parameters
if not prompt:
return self.create_text_message("Please provide a text prompt for video generation.")
if not video_output_s3uri:
return self.create_text_message("Please provide an S3 URI for video output.")
# Validate S3 URI format
if not video_output_s3uri.startswith("s3://"):
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
# Ensure S3 URI ends with '/'
video_output_s3uri = video_output_s3uri if video_output_s3uri.endswith("/") else video_output_s3uri + "/"
return {
"prompt": prompt,
"video_output_s3uri": video_output_s3uri,
"image_input_s3uri": tool_parameters.get("image_input_s3uri", "").strip(),
"aws_region": tool_parameters.get("aws_region", NOVA_REEL_DEFAULT_REGION),
"dimension": tool_parameters.get("dimension", NOVA_REEL_DEFAULT_DIMENSION),
"seed": int(tool_parameters.get("seed", 0)),
"fps": int(tool_parameters.get("fps", NOVA_REEL_DEFAULT_FPS)),
"duration": int(tool_parameters.get("duration", NOVA_REEL_DEFAULT_DURATION)),
"async_mode": bool(tool_parameters.get("async", True)),
}
def _initialize_aws_clients(self, region: str) -> tuple[Any, Any]:
"""Initialize AWS Bedrock and S3 clients."""
bedrock = boto3.client(service_name="bedrock-runtime", region_name=region)
s3_client = boto3.client("s3", region_name=region)
return bedrock, s3_client
def _prepare_model_input(self, params: dict[str, Any], s3_client: Any) -> Union[dict[str, Any], ToolInvokeMessage]:
"""Prepare the input for the Nova Reel model."""
model_input = {
"taskType": "TEXT_VIDEO",
"textToVideoParams": {"text": params["prompt"]},
"videoGenerationConfig": {
"durationSeconds": params["duration"],
"fps": params["fps"],
"dimension": params["dimension"],
"seed": params["seed"],
},
}
# Add image if provided
if params["image_input_s3uri"]:
try:
image_data = self._get_image_from_s3(s3_client, params["image_input_s3uri"])
if not image_data:
return self.create_text_message("Failed to retrieve image from S3")
# Process and validate image
processed_image = self._process_and_validate_image(image_data)
if isinstance(processed_image, ToolInvokeMessage):
return processed_image
# Convert processed image to base64
img_buffer = BytesIO()
processed_image.save(img_buffer, format="PNG")
img_buffer.seek(0)
input_image_base64 = base64.b64encode(img_buffer.getvalue()).decode("utf-8")
model_input["textToVideoParams"]["images"] = [
{"format": "png", "source": {"bytes": input_image_base64}}
]
except Exception as e:
logger.error(f"Error processing input image: {str(e)}", exc_info=True)
return self.create_text_message(f"Failed to process input image: {str(e)}")
return model_input
def _process_and_validate_image(self, image_data: bytes) -> Union[Image.Image, ToolInvokeMessage]:
"""
Process and validate the input image according to Nova Reel requirements.
Requirements:
- Must be 1280x720 pixels
- Must be RGB format (8 bits per channel)
- If PNG, alpha channel must not have transparent/translucent pixels
"""
try:
# Open image
img = Image.open(BytesIO(image_data))
# Convert RGBA to RGB if needed, ensuring no transparency
if img.mode == "RGBA":
# Check for transparency
if img.getchannel("A").getextrema()[0] < 255:
return self.create_text_message(
"PNG image contains transparent or translucent pixels, which is not supported. "
"Please provide an image without transparency."
)
# Convert to RGB
img = img.convert("RGB")
elif img.mode != "RGB":
# Convert any other mode to RGB
img = img.convert("RGB")
# Validate/adjust dimensions
if img.size != (NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT):
logger.warning(
f"Image dimensions {img.size} do not match required dimensions "
f"({NOVA_REEL_REQUIRED_IMAGE_WIDTH}x{NOVA_REEL_REQUIRED_IMAGE_HEIGHT}). Resizing..."
)
img = img.resize(
(NOVA_REEL_REQUIRED_IMAGE_WIDTH, NOVA_REEL_REQUIRED_IMAGE_HEIGHT), Image.Resampling.LANCZOS
)
# Validate bit depth
if img.mode != NOVA_REEL_REQUIRED_IMAGE_MODE:
return self.create_text_message(
f"Image must be in {NOVA_REEL_REQUIRED_IMAGE_MODE} mode with 8 bits per channel"
)
return img
except Exception as e:
logger.error(f"Error processing image: {str(e)}", exc_info=True)
return self.create_text_message(
"Failed to process image. Please ensure the image is a valid JPEG or PNG file."
)
def _get_image_from_s3(self, s3_client: Any, s3_uri: str) -> Optional[bytes]:
"""Download and return image data from S3."""
parsed_uri = urlparse(s3_uri)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip("/")
response = s3_client.get_object(Bucket=bucket, Key=key)
return response["Body"].read()
def _start_video_generation(self, bedrock: Any, model_input: dict[str, Any], output_s3uri: str) -> dict[str, Any]:
"""Start the async video generation process."""
return bedrock.start_async_invoke(
modelId=NOVA_REEL_MODEL_ID,
modelInput=model_input,
outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_s3uri}},
)
def _handle_generation_mode(
self, bedrock: Any, s3_client: Any, invocation_arn: str, async_mode: bool
) -> ToolInvokeMessage:
"""Handle async or sync video generation mode."""
invocation_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
video_path = invocation_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
video_uri = f"{video_path}/output.mp4"
if async_mode:
return self.create_text_message(
f"Video generation started.\nInvocation ARN: {invocation_arn}\n"
f"Video will be available at: {video_uri}"
)
return self._wait_for_completion(bedrock, s3_client, invocation_arn)
def _wait_for_completion(self, bedrock: Any, s3_client: Any, invocation_arn: str) -> ToolInvokeMessage:
"""Wait for video generation completion and handle the result."""
while True:
status_response = bedrock.get_async_invoke(invocationArn=invocation_arn)
status = status_response["status"]
video_path = status_response["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
if status == "Completed":
return self._handle_completed_video(s3_client, video_path)
elif status == "Failed":
failure_message = status_response.get("failureMessage", "Unknown error")
return self.create_text_message(f"Video generation failed.\nError: {failure_message}")
elif status == "InProgress":
time.sleep(NOVA_REEL_STATUS_CHECK_INTERVAL)
else:
return self.create_text_message(f"Unexpected status: {status}")
def _handle_completed_video(self, s3_client: Any, video_path: str) -> ToolInvokeMessage:
"""Handle completed video generation and return the result."""
parsed_uri = urlparse(video_path)
bucket = parsed_uri.netloc
key = parsed_uri.path.lstrip("/") + "/output.mp4"
try:
response = s3_client.get_object(Bucket=bucket, Key=key)
video_content = response["Body"].read()
return [
self.create_text_message(f"Video is available at: {video_path}/output.mp4"),
self.create_blob_message(blob=video_content, meta={"mime_type": "video/mp4"}, save_as="output.mp4"),
]
except Exception as e:
logger.error(f"Error downloading video: {str(e)}", exc_info=True)
return self.create_text_message(
f"Video generation completed but failed to download video: {str(e)}\n"
f"Video is available at: s3://{bucket}/{key}"
)
def get_runtime_parameters(self) -> list[ToolParameter]:
"""Define the tool's runtime parameters."""
parameters = [
ToolParameter(
name="prompt",
label=I18nObject(en_US="Prompt", zh_Hans="提示词"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Text description of the video you want to generate", zh_Hans="您想要生成的视频的文本描述"
),
llm_description="Describe the video you want to generate",
),
ToolParameter(
name="video_output_s3uri",
label=I18nObject(en_US="Output S3 URI", zh_Hans="输出S3 URI"),
type=ToolParameter.ToolParameterType.STRING,
required=True,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="S3 URI where the generated video will be stored", zh_Hans="生成的视频将存储的S3 URI"
),
),
ToolParameter(
name="dimension",
label=I18nObject(en_US="Dimension", zh_Hans="尺寸"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default=NOVA_REEL_DEFAULT_DIMENSION,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Video dimensions (width x height)", zh_Hans="视频尺寸(宽 x 高)"),
),
ToolParameter(
name="duration",
label=I18nObject(en_US="Duration", zh_Hans="时长"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=NOVA_REEL_DEFAULT_DURATION,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Video duration in seconds", zh_Hans="视频时长(秒)"),
),
ToolParameter(
name="seed",
label=I18nObject(en_US="Seed", zh_Hans="种子值"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=0,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="Random seed for video generation", zh_Hans="视频生成的随机种子"),
),
ToolParameter(
name="fps",
label=I18nObject(en_US="FPS", zh_Hans="帧率"),
type=ToolParameter.ToolParameterType.NUMBER,
required=False,
default=NOVA_REEL_DEFAULT_FPS,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(
en_US="Frames per second for the generated video", zh_Hans="生成视频的每秒帧数"
),
),
ToolParameter(
name="async",
label=I18nObject(en_US="Async Mode", zh_Hans="异步模式"),
type=ToolParameter.ToolParameterType.BOOLEAN,
required=False,
default=True,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="Whether to run in async mode (return immediately) or sync mode (wait for completion)",
zh_Hans="是否以异步模式运行(立即返回)或同步模式(等待完成)",
),
),
ToolParameter(
name="aws_region",
label=I18nObject(en_US="AWS Region", zh_Hans="AWS 区域"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
default=NOVA_REEL_DEFAULT_REGION,
form=ToolParameter.ToolParameterForm.FORM,
human_description=I18nObject(en_US="AWS region for Bedrock service", zh_Hans="Bedrock 服务的 AWS 区域"),
),
ToolParameter(
name="image_input_s3uri",
label=I18nObject(en_US="Input Image S3 URI", zh_Hans="输入图像S3 URI"),
type=ToolParameter.ToolParameterType.STRING,
required=False,
form=ToolParameter.ToolParameterForm.LLM,
human_description=I18nObject(
en_US="S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame",
zh_Hans="用作第一帧的输入图像1280x720 JPEG/PNG的S3 URI",
),
),
]
return parameters

View File

@ -0,0 +1,124 @@
identity:
name: nova_reel
author: AWS
label:
en_US: AWS Bedrock Nova Reel
zh_Hans: AWS Bedrock Nova Reel
icon: icon.svg
description:
human:
en_US: A tool for generating videos using AWS Bedrock's Nova Reel model. Supports text-to-video generation and image-to-video generation with customizable parameters like duration, FPS, and dimensions. Input parameters reference https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
zh_Hans: 使用 AWS Bedrock 的 Nova Reel 模型生成视频的工具。支持文本生成视频和图像生成视频功能,可自定义持续时间、帧率和尺寸等参数。输入参数参考 https://docs.aws.amazon.com/nova/latest/userguide/video-generation.html
llm: Generate videos using AWS Bedrock's Nova Reel model with support for both text-to-video and image-to-video generation, allowing customization of video properties like duration, frame rate, and resolution.
parameters:
- name: prompt
type: string
required: true
label:
en_US: Prompt
zh_Hans: 提示词
human_description:
en_US: Text description of the video you want to generate
zh_Hans: 您想要生成的视频的文本描述
llm_description: Describe the video you want to generate
form: llm
- name: video_output_s3uri
type: string
required: true
label:
en_US: Output S3 URI
zh_Hans: 输出S3 URI
human_description:
en_US: S3 URI where the generated video will be stored
zh_Hans: 生成的视频将存储的S3 URI
form: form
- name: dimension
type: string
required: false
default: 1280x720
label:
en_US: Dimension
zh_Hans: 尺寸
human_description:
en_US: Video dimensions (width x height)
zh_Hans: 视频尺寸(宽 x 高)
form: form
- name: duration
type: number
required: false
default: 6
label:
en_US: Duration
zh_Hans: 时长
human_description:
en_US: Video duration in seconds
zh_Hans: 视频时长(秒)
form: form
- name: seed
type: number
required: false
default: 0
label:
en_US: Seed
zh_Hans: 种子值
human_description:
en_US: Random seed for video generation
zh_Hans: 视频生成的随机种子
form: form
- name: fps
type: number
required: false
default: 24
label:
en_US: FPS
zh_Hans: 帧率
human_description:
en_US: Frames per second for the generated video
zh_Hans: 生成视频的每秒帧数
form: form
- name: async
type: boolean
required: false
default: true
label:
en_US: Async Mode
zh_Hans: 异步模式
human_description:
en_US: Whether to run in async mode (return immediately) or sync mode (wait for completion)
zh_Hans: 是否以异步模式运行(立即返回)或同步模式(等待完成)
form: llm
- name: aws_region
type: string
required: false
default: us-east-1
label:
en_US: AWS Region
zh_Hans: AWS 区域
human_description:
en_US: AWS region for Bedrock service
zh_Hans: Bedrock 服务的 AWS 区域
form: form
- name: image_input_s3uri
type: string
required: false
label:
en_US: Input Image S3 URI
zh_Hans: 输入图像S3 URI
human_description:
en_US: S3 URI of the input image (1280x720 JPEG/PNG) to use as first frame
zh_Hans: 用作第一帧的输入图像1280x720 JPEG/PNG的S3 URI
form: llm
development:
dependencies:
- boto3
- pillow

View File

@ -0,0 +1,80 @@
from typing import Any, Union
from urllib.parse import urlparse
import boto3
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool
class S3Operator(BuiltinTool):
s3_client: Any = None
def _invoke(
self,
user_id: str,
tool_parameters: dict[str, Any],
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
"""
invoke tools
"""
try:
# Initialize S3 client if not already done
if not self.s3_client:
aws_region = tool_parameters.get("aws_region")
if aws_region:
self.s3_client = boto3.client("s3", region_name=aws_region)
else:
self.s3_client = boto3.client("s3")
# Parse S3 URI
s3_uri = tool_parameters.get("s3_uri")
if not s3_uri:
return self.create_text_message("s3_uri parameter is required")
parsed_uri = urlparse(s3_uri)
if parsed_uri.scheme != "s3":
return self.create_text_message("Invalid S3 URI format. Must start with 's3://'")
bucket = parsed_uri.netloc
# Remove leading slash from key
key = parsed_uri.path.lstrip("/")
operation_type = tool_parameters.get("operation_type", "read")
generate_presign_url = tool_parameters.get("generate_presign_url", False)
presign_expiry = int(tool_parameters.get("presign_expiry", 3600)) # default 1 hour
if operation_type == "write":
text_content = tool_parameters.get("text_content")
if not text_content:
return self.create_text_message("text_content parameter is required for write operation")
# Write content to S3
self.s3_client.put_object(Bucket=bucket, Key=key, Body=text_content.encode("utf-8"))
result = f"s3://{bucket}/{key}"
# Generate presigned URL for the written object if requested
if generate_presign_url:
result = self.s3_client.generate_presigned_url(
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
)
else: # read operation
# Get object from S3
response = self.s3_client.get_object(Bucket=bucket, Key=key)
result = response["Body"].read().decode("utf-8")
# Generate presigned URL if requested
if generate_presign_url:
result = self.s3_client.generate_presigned_url(
"get_object", Params={"Bucket": bucket, "Key": key}, ExpiresIn=presign_expiry
)
return self.create_text_message(text=result)
except self.s3_client.exceptions.NoSuchBucket:
return self.create_text_message(f"Bucket '{bucket}' does not exist")
except self.s3_client.exceptions.NoSuchKey:
return self.create_text_message(f"Object '{key}' does not exist in bucket '{bucket}'")
except Exception as e:
return self.create_text_message(f"Exception: {str(e)}")

View File

@ -0,0 +1,98 @@
identity:
name: s3_operator
author: AWS
label:
en_US: AWS S3 Operator
zh_Hans: AWS S3 读写器
pt_BR: AWS S3 Operator
icon: icon.svg
description:
human:
en_US: AWS S3 Writer and Reader
zh_Hans: 读写S3 bucket中的文件
pt_BR: AWS S3 Writer and Reader
llm: AWS S3 Writer and Reader
parameters:
- name: text_content
type: string
required: false
label:
en_US: The text to write
zh_Hans: 待写入的文本
pt_BR: The text to write
human_description:
en_US: The text to write
zh_Hans: 待写入的文本
pt_BR: The text to write
llm_description: The text to write
form: llm
- name: s3_uri
type: string
required: true
label:
en_US: s3 uri
zh_Hans: s3 uri
pt_BR: s3 uri
human_description:
en_US: s3 uri
zh_Hans: s3 uri
pt_BR: s3 uri
llm_description: s3 uri
form: llm
- name: aws_region
type: string
required: true
label:
en_US: region of bucket
zh_Hans: bucket 所在的region
pt_BR: region of bucket
human_description:
en_US: region of bucket
zh_Hans: bucket 所在的region
pt_BR: region of bucket
llm_description: region of bucket
form: form
- name: operation_type
type: select
required: true
label:
en_US: operation type
zh_Hans: 操作类型
pt_BR: operation type
human_description:
en_US: operation type
zh_Hans: 操作类型
pt_BR: operation type
default: read
options:
- value: read
label:
en_US: read
zh_Hans:
- value: write
label:
en_US: write
zh_Hans:
form: form
- name: generate_presign_url
type: boolean
required: false
label:
en_US: Generate presigned URL
zh_Hans: 生成预签名URL
human_description:
en_US: Whether to generate a presigned URL for the S3 object
zh_Hans: 是否生成S3对象的预签名URL
default: false
form: form
- name: presign_expiry
type: number
required: false
label:
en_US: Presigned URL expiration time
zh_Hans: 预签名URL有效期
human_description:
en_US: Expiration time in seconds for the presigned URL
zh_Hans: 预签名URL的有效期
default: 3600
form: form

View File

@ -210,7 +210,7 @@ class ApiTool(Tool):
)
return response
else:
raise ValueError(f"Invalid http method {self.method}")
raise ValueError(f"Invalid http method {method}")
def _convert_body_property_any_of(
self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10

View File

@ -8,9 +8,10 @@ from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
from httpx import get
import httpx
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import MessageFile
@ -94,12 +95,11 @@ class ToolFileManager:
) -> ToolFile:
# try to download image
try:
response = get(file_url)
response = ssrf_proxy.get(file_url)
response.raise_for_status()
blob = response.content
except Exception as e:
logger.exception(f"Failed to download file from {file_url}")
raise
except httpx.TimeoutException as e:
raise ValueError(f"timeout when downloading file from {file_url}")
mimetype = guess_type(file_url)[0] or "octet/stream"
extension = guess_extension(mimetype) or ".bin"

View File

@ -1,4 +1,4 @@
class BaseNodeError(Exception):
class BaseNodeError(ValueError):
"""Base class for node errors."""
pass

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
from typing import Any, Optional, Union
from typing import Any, Optional
from configs import dify_config
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage
@ -59,7 +59,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
result = self._transform_result(result, self.node_data.outputs)
result = self._transform_result(result=result, output_schema=self.node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@ -67,18 +67,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result)
def _check_string(self, value: str, variable: str) -> str:
def _check_string(self, value: str | None, variable: str) -> str | None:
"""
Check string
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, str):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a string")
raise OutputValidationError(f"Output variable `{variable}` must be a string")
if len(value) > dify_config.CODE_MAX_STRING_LENGTH:
raise OutputValidationError(
@ -88,18 +87,17 @@ class CodeNode(BaseNode[CodeNodeData]):
return value.replace("\x00", "")
def _check_number(self, value: Union[int, float], variable: str) -> Union[int, float]:
def _check_number(self, value: int | float | None, variable: str) -> int | float | None:
"""
Check number
:param value: value
:param variable: variable
:return:
"""
if value is None:
return None
if not isinstance(value, int | float):
if value is None:
return None
else:
raise OutputValidationError(f"Output variable `{variable}` must be a number")
raise OutputValidationError(f"Output variable `{variable}` must be a number")
if value > dify_config.CODE_MAX_NUMBER or value < dify_config.CODE_MIN_NUMBER:
raise OutputValidationError(
@ -118,14 +116,12 @@ class CodeNode(BaseNode[CodeNodeData]):
return value
def _transform_result(
self, result: dict, output_schema: Optional[dict[str, CodeNodeData.Output]], prefix: str = "", depth: int = 1
) -> dict:
"""
Transform result
:param result: result
:param output_schema: output schema
:return:
"""
self,
result: Mapping[str, Any],
output_schema: Optional[dict[str, CodeNodeData.Output]],
prefix: str = "",
depth: int = 1,
):
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.")

View File

@ -1,6 +1,7 @@
import csv
import io
import json
import logging
import os
import tempfile
@ -22,6 +23,8 @@ from models.workflow import WorkflowNodeExecutionStatus
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
logger = logging.getLogger(__name__)
class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
"""
@ -177,10 +180,43 @@ def _extract_text_from_pdf(file_content: bytes) -> str:
def _extract_text_from_doc(file_content: bytes) -> str:
"""
Extract text from a DOC/DOCX file.
For now support only paragraph and table add more if needed
"""
try:
doc_file = io.BytesIO(file_content)
doc = docx.Document(doc_file)
return "\n".join([paragraph.text for paragraph in doc.paragraphs])
text = []
# Process paragraphs
for paragraph in doc.paragraphs:
if paragraph.text.strip():
text.append(paragraph.text)
# Process tables
for table in doc.tables:
# Table header
try:
# table maybe cause errors so ignore it.
if len(table.rows) > 0 and table.rows[0].cells is not None:
# Check if any cell in the table has text
has_content = False
for row in table.rows:
if any(cell.text.strip() for cell in row.cells):
has_content = True
break
if has_content:
markdown_table = "| " + " | ".join(cell.text for cell in table.rows[0].cells) + " |\n"
markdown_table += "| " + " | ".join(["---"] * len(table.rows[0].cells)) + " |\n"
for row in table.rows[1:]:
markdown_table += "| " + " | ".join(cell.text for cell in row.cells) + " |\n"
text.append(markdown_table)
except Exception as e:
logger.warning(f"Failed to extract table from DOC/DOCX: {e}")
continue
return "\n".join(text)
except Exception as e:
raise TextExtractionError(f"Failed to extract text from DOC/DOCX: {str(e)}") from e

View File

@ -179,6 +179,15 @@ class ParameterExtractorNode(LLMNode):
error=str(e),
metadata={},
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=inputs,
process_data=process_data,
outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)},
error=str(e),
metadata={},
)
error = None

View File

@ -154,8 +154,7 @@ class QuestionClassifierNode(LLMNode):
},
llm_usage=usage,
)
except ValueError as e:
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,

View File

@ -1,4 +1,4 @@
class VariableOperatorNodeError(Exception):
class VariableOperatorNodeError(ValueError):
"""Base error type, don't use directly."""
pass

View File

@ -116,8 +116,11 @@ def _build_from_local_file(
tenant_id: str,
transfer_method: FileTransferMethod,
) -> File:
upload_file_id = mapping.get("upload_file_id")
if not upload_file_id:
raise ValueError("Invalid upload file id")
stmt = select(UploadFile).where(
UploadFile.id == mapping.get("upload_file_id"),
UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id,
)

View File

@ -27,7 +27,7 @@ def parse_json_markdown(json_string: str) -> dict:
extracted_content = json_string[start_index:end_index].strip()
parsed = json.loads(extracted_content)
else:
raise Exception("Could not find JSON block in the output.")
raise ValueError("could not find json block in the output.")
return parsed
@ -36,10 +36,10 @@ def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserError(f"Got invalid JSON object. Error: {e}")
raise OutputParserError(f"got invalid json object. error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserError(
f"Got invalid return object. Expected key `{key}` to be present, but got {json_obj}"
f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
)
return json_obj

View File

@ -2,6 +2,7 @@ import enum
import json
from flask_login import UserMixin
from sqlalchemy import func
from .engine import db
from .types import StringUUID
@ -30,11 +31,11 @@ class Account(UserMixin, db.Model):
timezone = db.Column(db.String(255))
last_login_at = db.Column(db.DateTime)
last_login_ip = db.Column(db.String(255))
last_active_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
initialized_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def is_password_set(self):
@ -187,8 +188,8 @@ class Tenant(db.Model):
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
custom_config = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def get_accounts(self) -> list[Account]:
return (
@ -228,8 +229,8 @@ class TenantAccountJoin(db.Model):
current = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
role = db.Column(db.String(16), nullable=False, server_default="normal")
invited_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model):
@ -245,8 +246,8 @@ class AccountIntegrate(db.Model):
provider = db.Column(db.String(16), nullable=False)
open_id = db.Column(db.String(255), nullable=False)
encrypted_token = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model):
@ -265,4 +266,4 @@ class InvitationCode(db.Model):
used_by_tenant_id = db.Column(StringUUID)
used_by_account_id = db.Column(StringUUID)
deprecated_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -1,5 +1,7 @@
import enum
from sqlalchemy import func
from .engine import db
from .types import StringUUID
@ -23,4 +25,4 @@ class APIBasedExtension(db.Model):
name = db.Column(db.String(255), nullable=False)
api_endpoint = db.Column(db.String(255), nullable=False)
api_key = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -50,9 +50,9 @@ class Dataset(db.Model):
indexing_technique = db.Column(db.String(255), nullable=True)
index_struct = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(StringUUID, nullable=True)
@ -212,7 +212,7 @@ class DatasetProcessRule(db.Model):
mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
rules = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
@ -264,7 +264,7 @@ class Document(db.Model):
created_from = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_api_request_id = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
# start processing
processing_started_at = db.Column(db.DateTime, nullable=True)
@ -303,7 +303,7 @@ class Document(db.Model):
archived_reason = db.Column(db.String(255), nullable=True)
archived_by = db.Column(StringUUID, nullable=True)
archived_at = db.Column(db.DateTime, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
doc_type = db.Column(db.String(40), nullable=True)
doc_metadata = db.Column(db.JSON, nullable=True)
doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
@ -527,9 +527,9 @@ class DocumentSegment(db.Model):
disabled_by = db.Column(StringUUID, nullable=True)
status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)
@ -697,7 +697,7 @@ class Embedding(db.Model):
)
hash = db.Column(db.String(64), nullable=False)
embedding = db.Column(db.LargeBinary, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
def set_embedding(self, embedding_data: list[float]):
@ -719,7 +719,7 @@ class DatasetCollectionBinding(db.Model):
model_name = db.Column(db.String(255), nullable=False)
type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model):
@ -739,7 +739,7 @@ class TidbAuthBinding(db.Model):
status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))
account = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model):
@ -751,7 +751,7 @@ class Whitelist(db.Model):
id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=True)
category = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model):
@ -768,7 +768,7 @@ class DatasetPermission(db.Model):
account_id = db.Column(StringUUID, nullable=False)
tenant_id = db.Column(StringUUID, nullable=False)
has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model):
@ -785,9 +785,9 @@ class ExternalKnowledgeApis(db.Model):
tenant_id = db.Column(StringUUID, nullable=False)
settings = db.Column(db.Text, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
return {
@ -840,6 +840,6 @@ class ExternalKnowledgeBindings(db.Model):
dataset_id = db.Column(StringUUID, nullable=False)
external_knowledge_id = db.Column(db.Text, nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -30,7 +30,7 @@ class DifySetup(db.Model):
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
version = db.Column(db.String(255), nullable=False)
setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMode(StrEnum):
@ -85,9 +85,9 @@ class App(db.Model):
tracing = db.Column(db.Text, nullable=True)
max_active_requests = db.Column(db.Integer, nullable=True)
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
@property
@ -226,9 +226,9 @@ class AppModelConfig(db.Model):
model_id = db.Column(db.String(255), nullable=True)
configs = db.Column(db.JSON, nullable=True)
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
opening_statement = db.Column(db.Text)
suggested_questions = db.Column(db.Text)
suggested_questions_after_answer = db.Column(db.Text)
@ -482,8 +482,8 @@ class RecommendedApp(db.Model):
is_listed = db.Column(db.Boolean, nullable=False, default=True)
install_count = db.Column(db.Integer, nullable=False, default=0)
language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
@ -507,7 +507,7 @@ class InstalledApp(db.Model):
position = db.Column(db.Integer, nullable=False, default=0)
is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
last_used_at = db.Column(db.DateTime, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
@ -548,8 +548,8 @@ class Conversation(db.Model):
read_at = db.Column(db.DateTime)
read_account_id = db.Column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
@ -700,8 +700,10 @@ class Conversation(db.Model):
def status_count(self):
messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
status_counts = {
WorkflowRunStatus.RUNNING: 0,
WorkflowRunStatus.SUCCEEDED: 0,
WorkflowRunStatus.FAILED: 0,
WorkflowRunStatus.STOPPED: 0,
WorkflowRunStatus.PARTIAL_SUCCESSED: 0,
}
@ -789,8 +791,8 @@ class Message(db.Model):
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
workflow_run_id = db.Column(StringUUID)
@ -1115,8 +1117,8 @@ class MessageFeedback(db.Model):
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def from_account(self):
@ -1162,9 +1164,7 @@ class MessageFile(db.Model):
upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False)
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
created_at: Mapped[datetime] = db.Column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model):
@ -1184,8 +1184,8 @@ class MessageAnnotation(db.Model):
content = db.Column(db.Text, nullable=False)
hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
account_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def account(self):
@ -1214,7 +1214,7 @@ class AppAnnotationHitHistory(db.Model):
source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
score = db.Column(Float, nullable=False, server_default=db.text("0"))
message_id = db.Column(StringUUID, nullable=False)
annotation_question = db.Column(db.Text, nullable=False)
@ -1248,9 +1248,9 @@ class AppAnnotationSetting(db.Model):
score_threshold = db.Column(Float, nullable=False, server_default=db.text("0"))
collection_binding_id = db.Column(StringUUID, nullable=False)
created_user_id = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
@ -1296,9 +1296,9 @@ class OperationLog(db.Model):
account_id = db.Column(StringUUID, nullable=False)
action = db.Column(db.String(255), nullable=False)
content = db.Column(db.JSON)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_ip = db.Column(db.String(255), nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class EndUser(UserMixin, db.Model):
@ -1317,8 +1317,8 @@ class EndUser(UserMixin, db.Model):
name = db.Column(db.String(255))
is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
session_id = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Site(db.Model):
@ -1349,9 +1349,9 @@ class Site(db.Model):
prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
created_by = db.Column(StringUUID, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
code = db.Column(db.String(255))
@property
@ -1393,7 +1393,7 @@ class ApiToken(db.Model):
type = db.Column(db.String(16), nullable=False)
token = db.Column(db.String(255), nullable=False)
last_used_at = db.Column(db.DateTime, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_api_key(prefix, n):
@ -1424,9 +1424,7 @@ class UploadFile(db.Model):
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
)
created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
created_at: Mapped[datetime] = db.Column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True)
used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
@ -1483,7 +1481,7 @@ class ApiRequest(db.Model):
request = db.Column(db.Text, nullable=True)
response = db.Column(db.Text, nullable=True)
ip = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageChain(db.Model):
@ -1655,7 +1653,7 @@ class Tag(db.Model):
type = db.Column(db.String(16), nullable=False)
name = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TagBinding(db.Model):
@ -1671,7 +1669,7 @@ class TagBinding(db.Model):
tag_id = db.Column(StringUUID, nullable=True)
target_id = db.Column(StringUUID, nullable=True)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TraceAppConfig(db.Model):
@ -1685,8 +1683,10 @@ class TraceAppConfig(db.Model):
app_id = db.Column(StringUUID, nullable=False)
tracing_provider = db.Column(db.String(255), nullable=True)
tracing_config = db.Column(db.JSON, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.now())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.now(), onupdate=func.now())
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
@property

View File

@ -1,5 +1,7 @@
from enum import Enum
from sqlalchemy import func
from .engine import db
from .types import StringUUID
@ -60,8 +62,8 @@ class Provider(db.Model):
quota_limit = db.Column(db.BigInteger, nullable=True)
quota_used = db.Column(db.BigInteger, default=0)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def __repr__(self):
return (
@ -108,8 +110,8 @@ class ProviderModel(db.Model):
model_type = db.Column(db.String(40), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantDefaultModel(db.Model):
@ -124,8 +126,8 @@ class TenantDefaultModel(db.Model):
provider_name = db.Column(db.String(255), nullable=False)
model_name = db.Column(db.String(255), nullable=False)
model_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TenantPreferredModelProvider(db.Model):
@ -139,8 +141,8 @@ class TenantPreferredModelProvider(db.Model):
tenant_id = db.Column(StringUUID, nullable=False)
provider_name = db.Column(db.String(255), nullable=False)
preferred_provider_type = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderOrder(db.Model):
@ -164,8 +166,8 @@ class ProviderOrder(db.Model):
paid_at = db.Column(db.DateTime)
pay_failed_at = db.Column(db.DateTime)
refunded_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ProviderModelSetting(db.Model):
@ -186,8 +188,8 @@ class ProviderModelSetting(db.Model):
model_type = db.Column(db.String(40), nullable=False)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class LoadBalancingModelConfig(db.Model):
@ -209,5 +211,5 @@ class LoadBalancingModelConfig(db.Model):
name = db.Column(db.String(255), nullable=False)
encrypted_config = db.Column(db.Text, nullable=True)
enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -1,5 +1,6 @@
import json
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from .engine import db
@ -19,8 +20,8 @@ class DataSourceOauthBinding(db.Model):
access_token = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False)
source_info = db.Column(JSONB, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
@ -37,8 +38,8 @@ class DataSourceApiKeyAuthBinding(db.Model):
category = db.Column(db.String(255), nullable=False)
provider = db.Column(db.String(255), nullable=False)
credentials = db.Column(db.Text, nullable=True) # JSON
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
def to_dict(self):

View File

@ -2,7 +2,7 @@ import json
from typing import Optional
import sqlalchemy as sa
from sqlalchemy import ForeignKey
from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column
from core.tools.entities.common_entities import I18nObject
@ -36,8 +36,8 @@ class BuiltinToolProvider(db.Model):
provider = db.Column(db.String(40), nullable=False)
# credential of the tool provider
encrypted_credentials = db.Column(db.Text, nullable=True)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def credentials(self) -> dict:
@ -74,8 +74,8 @@ class PublishedAppTool(db.Model):
tool_name = db.Column(db.String(40), nullable=False)
# author
author = db.Column(db.String(40), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def description_i18n(self) -> I18nObject:
@ -120,8 +120,8 @@ class ApiToolProvider(db.Model):
# custom_disclaimer
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def schema_type(self) -> ApiProviderSchemaType:
@ -198,8 +198,8 @@ class WorkflowToolProvider(db.Model):
# privacy policy
privacy_policy = db.Column(db.String(255), nullable=True, server_default="")
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def user(self) -> Account | None:
@ -251,8 +251,8 @@ class ToolModelInvoke(db.Model):
provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_price = db.Column(db.Numeric(10, 7))
currency = db.Column(db.String(255), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ToolConversationVariables(db.Model):
@ -278,8 +278,8 @@ class ToolConversationVariables(db.Model):
# variables pool
variables_str = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def variables(self) -> dict:

View File

@ -1,3 +1,6 @@
from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column
from .engine import db
from .model import Message
from .types import StringUUID
@ -15,7 +18,7 @@ class SavedMessage(db.Model):
message_id = db.Column(StringUUID, nullable=False)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def message(self):
@ -31,7 +34,7 @@ class PinnedConversation(db.Model):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False)
conversation_id = db.Column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

View File

@ -103,12 +103,13 @@ class Workflow(db.Model):
graph: Mapped[str] = mapped_column(sa.Text)
_features: Mapped[str] = mapped_column("features", sa.TEXT)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[Optional[str]] = mapped_column(StringUUID)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp()
db.DateTime,
nullable=False,
default=datetime.now(UTC).replace(tzinfo=None),
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
"environment_variables", db.Text, nullable=False, server_default="{}"
@ -406,7 +407,7 @@ class WorkflowRun(db.Model):
total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role = db.Column(db.String(255), nullable=False) # account, end_user
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime)
exceptions_count = db.Column(db.Integer, server_default=db.text("0"))
@ -636,7 +637,7 @@ class WorkflowNodeExecution(db.Model):
error = db.Column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
execution_metadata = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
finished_at = db.Column(db.DateTime)
@ -754,7 +755,7 @@ class WorkflowAppLog(db.Model):
created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def workflow_run(self):
@ -780,7 +781,7 @@ class ConversationVariable(db.Model):
conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
data = db.Column(db.Text, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)"))
created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=func.current_timestamp())
updated_at = db.Column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)

View File

@ -340,7 +340,10 @@ class AppDslService:
) -> App:
"""Create a new app or update an existing one."""
app_data = data.get("app", {})
app_mode = AppMode(app_data["mode"])
app_mode = app_data.get("mode")
if not app_mode:
raise ValueError("loss app mode")
app_mode = AppMode(app_mode)
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")

View File

@ -1,8 +1,9 @@
from collections.abc import Callable
from collections.abc import Callable, Sequence
from datetime import UTC, datetime
from typing import Optional, Union
from sqlalchemy import asc, desc, or_
from sqlalchemy import asc, desc, func, or_, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
@ -18,19 +19,21 @@ class ConversationService:
@classmethod
def pagination_by_last_id(
cls,
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
limit: int,
invoke_from: InvokeFrom,
include_ids: Optional[list] = None,
exclude_ids: Optional[list] = None,
include_ids: Optional[Sequence[str]] = None,
exclude_ids: Optional[Sequence[str]] = None,
sort_by: str = "-updated_at",
) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = db.session.query(Conversation).filter(
stmt = select(Conversation).where(
Conversation.is_deleted == False,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
@ -38,37 +41,40 @@ class ConversationService:
Conversation.from_account_id == (user.id if isinstance(user, Account) else None),
or_(Conversation.invoke_from.is_(None), Conversation.invoke_from == invoke_from.value),
)
if include_ids is not None:
base_query = base_query.filter(Conversation.id.in_(include_ids))
stmt = stmt.where(Conversation.id.in_(include_ids))
if exclude_ids is not None:
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
stmt = stmt.where(~Conversation.id.in_(exclude_ids))
# define sort fields and directions
sort_field, sort_direction = cls._get_sort_params(sort_by)
if last_id:
last_conversation = base_query.filter(Conversation.id == last_id).first()
last_conversation = session.scalar(stmt.where(Conversation.id == last_id))
if not last_conversation:
raise LastConversationNotExistsError()
# build filters based on sorting
filter_condition = cls._build_filter_condition(sort_field, sort_direction, last_conversation)
base_query = base_query.filter(filter_condition)
base_query = base_query.order_by(sort_direction(getattr(Conversation, sort_field)))
conversations = base_query.limit(limit).all()
filter_condition = cls._build_filter_condition(
sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=last_conversation,
)
stmt = stmt.where(filter_condition)
query_stmt = stmt.order_by(sort_direction(getattr(Conversation, sort_field))).limit(limit)
conversations = session.scalars(query_stmt).all()
has_more = False
if len(conversations) == limit:
current_page_last_conversation = conversations[-1]
rest_filter_condition = cls._build_filter_condition(
sort_field, sort_direction, current_page_last_conversation, is_next_page=True
sort_field=sort_field,
sort_direction=sort_direction,
reference_conversation=current_page_last_conversation,
)
rest_count = base_query.filter(rest_filter_condition).count()
count_stmt = stmt.where(rest_filter_condition)
count_stmt = select(func.count()).select_from(count_stmt.subquery())
rest_count = session.scalar(count_stmt) or 0
if rest_count > 0:
has_more = True
@ -81,11 +87,9 @@ class ConversationService:
return sort_by, asc
@classmethod
def _build_filter_condition(
cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
):
def _build_filter_condition(cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation):
field_value = getattr(reference_conversation, sort_field)
if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
if sort_direction == desc:
return getattr(Conversation, sort_field) < field_value
else:
return getattr(Conversation, sort_field) > field_value

View File

@ -151,7 +151,12 @@ class MessageService:
@classmethod
def create_feedback(
cls, app_model: App, message_id: str, user: Optional[Union[Account, EndUser]], rating: Optional[str]
cls,
app_model: App,
message_id: str,
user: Optional[Union[Account, EndUser]],
rating: Optional[str],
content: Optional[str],
) -> MessageFeedback:
if not user:
raise ValueError("user cannot be None")
@ -164,6 +169,7 @@ class MessageService:
db.session.delete(feedback)
elif rating and feedback:
feedback.rating = rating
feedback.content = content
elif not rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
else:
@ -172,6 +178,7 @@ class MessageService:
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating,
content=content,
from_source=("user" if isinstance(user, EndUser) else "admin"),
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
from_account_id=(user.id if isinstance(user, Account) else None),

View File

@ -2,6 +2,9 @@ import json
import logging
from pathlib import Path
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
@ -32,7 +35,7 @@ class BuiltinToolManageService:
tenant_id=tenant_id, provider_controller=provider_controller
)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = (
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -71,19 +74,18 @@ class BuiltinToolManageService:
return jsonable_encoder([v for _, v in (provider.credentials_schema or {}).items()])
@staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
def update_builtin_tool_provider(
session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
"""
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
.first()
stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
)
provider = session.scalar(stmt)
try:
# get provider
@ -115,13 +117,10 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(credentials),
)
db.session.add(provider)
db.session.commit()
session.add(provider)
else:
provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
@ -129,15 +128,15 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = (
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
BuiltinToolProvider.provider == provider_name,
)
.first()
)
@ -156,7 +155,7 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
provider: BuiltinToolProvider = (
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,

View File

@ -1,5 +1,8 @@
from typing import Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
@ -13,6 +16,8 @@ class WebConversationService:
@classmethod
def pagination_by_last_id(
cls,
*,
session: Session,
app_model: App,
user: Optional[Union[Account, EndUser]],
last_id: Optional[str],
@ -23,24 +28,25 @@ class WebConversationService:
) -> InfiniteScrollPagination:
include_ids = None
exclude_ids = None
if pinned is not None:
pinned_conversations = (
db.session.query(PinnedConversation)
.filter(
if pinned is not None and user:
stmt = (
select(PinnedConversation.conversation_id)
.where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
PinnedConversation.created_by == user.id,
)
.order_by(PinnedConversation.created_at.desc())
.all()
)
pinned_conversation_ids = [pc.conversation_id for pc in pinned_conversations]
pinned_conversation_ids = session.scalars(stmt).all()
if pinned:
include_ids = pinned_conversation_ids
else:
exclude_ids = pinned_conversation_ids
return ConversationService.pagination_by_last_id(
session=session,
app_model=app_model,
user=user,
last_id=last_id,

View File

@ -923,6 +923,3 @@ CREATE_TIDB_SERVICE_JOB_ENABLED=false
# Maximum number of submitted thread count in a ThreadPool for parallel node execution
MAX_SUBMIT_COUNT=100
# Proxy
HTTP_PROXY=
HTTPS_PROXY=

View File

@ -386,8 +386,6 @@ x-shared-env: &shared-api-worker-env
CSP_WHITELIST: ${CSP_WHITELIST:-}
CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false}
MAX_SUBMIT_COUNT: ${MAX_SUBMIT_COUNT:-100}
HTTP_PROXY: ${HTTP_PROXY:-}
HTTPS_PROXY: ${HTTPS_PROXY:-}
services:
# API service

View File

@ -64,6 +64,12 @@ const WorkflowProcessItem = ({
setShowMessageLogModal(true)
}, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
const showRetryDetail = useCallback(() => {
setCurrentLogItem(item)
setCurrentLogModalActiveTab('TRACING')
setShowMessageLogModal(true)
}, [item, setCurrentLogItem, setCurrentLogModalActiveTab, setShowMessageLogModal])
return (
<div
className={cn(
@ -105,6 +111,7 @@ const WorkflowProcessItem = ({
<TracingPanel
list={data.tracing}
onShowIterationDetail={showIterationDetail}
onShowRetryDetail={showRetryDetail}
hideNodeInfo={hideInfo}
hideNodeProcessDetail={hideProcessDetail}
/>

View File

@ -28,6 +28,7 @@ export type InputProps = {
destructive?: boolean
wrapperClassName?: string
styleCss?: CSSProperties
unit?: string
} & React.InputHTMLAttributes<HTMLInputElement> & VariantProps<typeof inputVariants>
const Input = ({
@ -43,6 +44,7 @@ const Input = ({
value,
placeholder,
onChange,
unit,
...props
}: InputProps) => {
const { t } = useTranslation()
@ -80,6 +82,13 @@ const Input = ({
{destructive && (
<RiErrorWarningLine className='absolute right-2 top-1/2 -translate-y-1/2 w-4 h-4 text-text-destructive-secondary' />
)}
{
unit && (
<div className='absolute right-2 top-1/2 -translate-y-1/2 system-sm-regular text-text-tertiary'>
{unit}
</div>
)
}
</div>
)
}

View File

@ -23,6 +23,7 @@ const SearchInput: FC<SearchInputProps> = ({
const { t } = useTranslation()
const [focus, setFocus] = useState<boolean>(false)
const isComposing = useRef<boolean>(false)
const [internalValue, setInternalValue] = useState<string>(value)
return (
<div className={cn(
@ -45,16 +46,18 @@ const SearchInput: FC<SearchInputProps> = ({
white && '!bg-white hover:!bg-white group-hover:!bg-white placeholder:!text-gray-400',
)}
placeholder={placeholder || t('common.operation.search')!}
value={value}
value={internalValue}
onChange={(e) => {
setInternalValue(e.target.value)
if (!isComposing.current)
onChange(e.target.value)
}}
onCompositionStart={() => {
isComposing.current = true
}}
onCompositionEnd={() => {
onCompositionEnd={(e) => {
isComposing.current = false
onChange(e.data)
}}
onFocus={() => setFocus(true)}
onBlur={() => setFocus(false)}
@ -63,7 +66,10 @@ const SearchInput: FC<SearchInputProps> = ({
{value && (
<div
className='shrink-0 flex items-center justify-center w-4 h-4 cursor-pointer group/clear'
onClick={() => onChange('')}
onClick={() => {
onChange('')
setInternalValue('')
}}
>
<XCircle className='w-3.5 h-3.5 text-gray-400 group-hover/clear:text-gray-600' />
</div>

View File

@ -346,6 +346,9 @@ The text generation application offers non-session support and is ideal for tran
<Property name='user' type='string' key='user'>
User identifier, defined by the developer's rules, must be unique within the application.
</Property>
<Property name='content' type='string' key='content'>
The specific content of message feedback.
</Property>
</Properties>
### Response
@ -353,7 +356,7 @@ The text generation application offers non-session support and is ideal for tran
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -361,7 +364,8 @@ The text generation application offers non-session support and is ideal for tran
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -345,6 +345,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='user' type='string' key='user'>
開発者のルールで定義されたユーザー識別子。アプリケーション内で一意である必要があります。
</Property>
<Property name='content' type='string' key='content'>
メッセージのフィードバックです。
</Property>
</Properties>
### レスポンス
@ -352,7 +355,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -360,7 +363,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -320,6 +320,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='user' type='string' key='user'>
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
</Property>
<Property name='content' type='string' key='content'>
消息反馈的具体信息。
</Property>
</Properties>
### Response
@ -327,7 +330,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -335,7 +338,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -444,6 +444,9 @@ Chat applications support session persistence, allowing previous chat history to
<Property name='user' type='string' key='user'>
User identifier, defined by the developer's rules, must be unique within the application.
</Property>
<Property name='content' type='string' key='content'>
The specific content of message feedback.
</Property>
</Properties>
### Response
@ -451,7 +454,7 @@ Chat applications support session persistence, allowing previous chat history to
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -459,7 +462,8 @@ Chat applications support session persistence, allowing previous chat history to
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -444,6 +444,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='user' type='string' key='user'>
ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。
</Property>
<Property name='content' type='string' key='content'>
メッセージのフィードバックです。
</Property>
</Properties>
### 応答
@ -451,7 +454,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Col>
<Col sticky>
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -459,7 +462,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -450,6 +450,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='user' type='string' key='user'>
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
</Property>
<Property name='content' type='string' key='content'>
消息反馈的具体信息。
</Property>
</Properties>
### Response
@ -457,7 +460,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -465,7 +468,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -408,6 +408,9 @@ Chat applications support session persistence, allowing previous chat history to
<Property name='user' type='string' key='user'>
User identifier, defined by the developer's rules, must be unique within the application.
</Property>
<Property name='content' type='string' key='content'>
The specific content of message feedback.
</Property>
</Properties>
### Response
@ -415,7 +418,7 @@ Chat applications support session persistence, allowing previous chat history to
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -423,7 +426,8 @@ Chat applications support session persistence, allowing previous chat history to
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -408,6 +408,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
<Property name='user' type='string' key='user'>
ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。
</Property>
<Property name='content' type='string' key='content'>
メッセージのフィードバックです。
</Property>
</Properties>
### 応答
@ -415,7 +418,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
</Col>
<Col sticky>
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="リクエスト" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n --header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -423,7 +426,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -423,6 +423,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
<Property name='user' type='string' key='user'>
用户标识,由开发者定义规则,需保证用户标识在应用内唯一。
</Property>
<Property name='content' type='string' key='content'>
消息反馈的具体信息。
</Property>
</Properties>
### Response
@ -430,7 +433,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
</Col>
<Col sticky>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123"\n}'`}>
<CodeGroup title="Request" tag="POST" label="/messages/:message_id/feedbacks" targetCode={`curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks \\\n--header 'Authorization: Bearer {api_key}' \\\n--header 'Content-Type: application/json' \\\n--data-raw '{\n "rating": "like",\n "user": "abc-123",\n "content": "message feedback information"\n}'`}>
```bash {{ title: 'cURL' }}
curl -X POST '${props.appDetail.api_base_url}/messages/:message_id/feedbacks' \
@ -438,7 +441,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx'
--header 'Content-Type: application/json' \
--data-raw '{
"rating": "like",
"user": "abc-123"
"user": "abc-123",
"content": "message feedback information"
}'
```

View File

@ -506,3 +506,5 @@ export const WORKFLOW_DATA_UPDATE = 'WORKFLOW_DATA_UPDATE'
export const CUSTOM_NODE = 'custom'
export const CUSTOM_EDGE = 'custom'
export const DSL_EXPORT_CHECK = 'DSL_EXPORT_CHECK'
export const DEFAULT_RETRY_MAX = 3
export const DEFAULT_RETRY_INTERVAL = 100

View File

@ -28,6 +28,7 @@ import {
getFilesInLogs,
} from '@/app/components/base/file-uploader/utils'
import { ErrorHandleTypeEnum } from '@/app/components/workflow/nodes/_base/components/error-handle/types'
import type { NodeTracing } from '@/types/workflow'
export const useWorkflowRun = () => {
const store = useStoreApi()
@ -114,6 +115,7 @@ export const useWorkflowRun = () => {
onIterationStart,
onIterationNext,
onIterationFinish,
onNodeRetry,
onError,
...restCallback
} = callback || {}
@ -440,10 +442,13 @@ export const useWorkflowRun = () => {
})
if (currentIndex > -1 && draft.tracing) {
draft.tracing[currentIndex] = {
...data,
...(draft.tracing[currentIndex].extras
? { extras: draft.tracing[currentIndex].extras }
: {}),
...data,
...(draft.tracing[currentIndex].retryDetail
? { retryDetail: draft.tracing[currentIndex].retryDetail }
: {}),
} as any
}
}))
@ -616,6 +621,41 @@ export const useWorkflowRun = () => {
if (onIterationFinish)
onIterationFinish(params)
},
onNodeRetry: (params) => {
const { data } = params
const {
workflowRunningData,
setWorkflowRunningData,
} = workflowStore.getState()
const {
getNodes,
setNodes,
} = store.getState()
const nodes = getNodes()
setWorkflowRunningData(produce(workflowRunningData!, (draft) => {
const tracing = draft.tracing!
const currentRetryNodeIndex = tracing.findIndex(trace => trace.node_id === data.node_id)
if (currentRetryNodeIndex > -1) {
const currentRetryNode = tracing[currentRetryNodeIndex]
if (currentRetryNode.retryDetail)
draft.tracing![currentRetryNodeIndex].retryDetail!.push(data as NodeTracing)
else
draft.tracing![currentRetryNodeIndex].retryDetail = [data as NodeTracing]
}
}))
const newNodes = produce(nodes, (draft) => {
const currentNode = draft.find(node => node.id === data.node_id)!
currentNode.data._retryIndex = data.retry_index
})
setNodes(newNodes)
if (onNodeRetry)
onNodeRetry(params)
},
onParallelBranchStarted: (params) => {
// console.log(params, 'parallel start')
},

View File

@ -17,17 +17,25 @@ import ResultPanel from '@/app/components/workflow/run/result-panel'
import Toast from '@/app/components/base/toast'
import { TransferMethod } from '@/types/app'
import { getProcessedFiles } from '@/app/components/base/file-uploader/utils'
import type { NodeTracing } from '@/types/workflow'
import RetryResultPanel from '@/app/components/workflow/run/retry-result-panel'
import type { BlockEnum } from '@/app/components/workflow/types'
import type { Emoji } from '@/app/components/tools/types'
const i18nPrefix = 'workflow.singleRun'
type BeforeRunFormProps = {
nodeName: string
nodeType?: BlockEnum
toolIcon?: string | Emoji
onHide: () => void
onRun: (submitData: Record<string, any>) => void
onStop: () => void
runningStatus: NodeRunningStatus
result?: JSX.Element
forms: FormProps[]
retryDetails?: NodeTracing[]
onRetryDetailBack?: any
}
function formatValue(value: string | any, type: InputVarType) {
@ -50,12 +58,16 @@ function formatValue(value: string | any, type: InputVarType) {
}
const BeforeRunForm: FC<BeforeRunFormProps> = ({
nodeName,
nodeType,
toolIcon,
onHide,
onRun,
onStop,
runningStatus,
result,
forms,
retryDetails,
onRetryDetailBack = () => { },
}) => {
const { t } = useTranslation()
@ -122,48 +134,69 @@ const BeforeRunForm: FC<BeforeRunFormProps> = ({
<div className='text-base font-semibold text-gray-900 truncate'>
{t(`${i18nPrefix}.testRun`)} {nodeName}
</div>
<div className='ml-2 shrink-0 p-1 cursor-pointer' onClick={onHide}>
<div className='ml-2 shrink-0 p-1 cursor-pointer' onClick={() => {
onHide()
}}>
<RiCloseLine className='w-4 h-4 text-gray-500 ' />
</div>
</div>
<div className='h-0 grow overflow-y-auto pb-4'>
<div className='mt-3 px-4 space-y-4'>
{forms.map((form, index) => (
<div key={index}>
<Form
key={index}
className={cn(index < forms.length - 1 && 'mb-4')}
{...form}
/>
{index < forms.length - 1 && <Split />}
{
retryDetails?.length && (
<div className='h-0 grow overflow-y-auto pb-4'>
<RetryResultPanel
list={retryDetails.map((item, index) => ({
...item,
title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`,
node_type: nodeType!,
extras: {
icon: toolIcon!,
},
}))}
onBack={onRetryDetailBack}
/>
</div>
)
}
{
!retryDetails?.length && (
<div className='h-0 grow overflow-y-auto pb-4'>
<div className='mt-3 px-4 space-y-4'>
{forms.map((form, index) => (
<div key={index}>
<Form
key={index}
className={cn(index < forms.length - 1 && 'mb-4')}
{...form}
/>
{index < forms.length - 1 && <Split />}
</div>
))}
</div>
))}
</div>
<div className='mt-4 flex justify-between space-x-2 px-4' >
{isRunning && (
<div
className='p-2 rounded-lg border border-gray-200 bg-white shadow-xs cursor-pointer'
onClick={onStop}
>
<StopCircle className='w-4 h-4 text-gray-500' />
<div className='mt-4 flex justify-between space-x-2 px-4' >
{isRunning && (
<div
className='p-2 rounded-lg border border-gray-200 bg-white shadow-xs cursor-pointer'
onClick={onStop}
>
<StopCircle className='w-4 h-4 text-gray-500' />
</div>
)}
<Button disabled={!isFileLoaded || isRunning} variant='primary' className='w-0 grow space-x-2' onClick={handleRun}>
{isRunning && <RiLoader2Line className='animate-spin w-4 h-4 text-white' />}
<div>{t(`${i18nPrefix}.${isRunning ? 'running' : 'startRun'}`)}</div>
</Button>
</div>
)}
<Button disabled={!isFileLoaded || isRunning} variant='primary' className='w-0 grow space-x-2' onClick={handleRun}>
{isRunning && <RiLoader2Line className='animate-spin w-4 h-4 text-white' />}
<div>{t(`${i18nPrefix}.${isRunning ? 'running' : 'startRun'}`)}</div>
</Button>
</div>
{isRunning && (
<ResultPanel status='running' showSteps={false} />
)}
{isFinished && (
<>
{result}
</>
)}
</div>
{isRunning && (
<ResultPanel status='running' showSteps={false} />
)}
{isFinished && (
<>
{result}
</>
)}
</div>
)
}
</div>
</div>
)

View File

@ -14,7 +14,6 @@ import type {
CommonNodeType,
Node,
} from '@/app/components/workflow/types'
import Split from '@/app/components/workflow/nodes/_base/components/split'
import Tooltip from '@/app/components/base/tooltip'
type ErrorHandleProps = Pick<Node, 'id' | 'data'>
@ -45,7 +44,6 @@ const ErrorHandle = ({
return (
<>
<Split />
<div className='py-4'>
<Collapse
disabled={!error_strategy}

View File

@ -0,0 +1,41 @@
import {
useCallback,
useState,
} from 'react'
import type { WorkflowRetryConfig } from './types'
import {
useNodeDataUpdate,
} from '@/app/components/workflow/hooks'
import type { NodeTracing } from '@/types/workflow'
export const useRetryConfig = (
id: string,
) => {
const { handleNodeDataUpdateWithSyncDraft } = useNodeDataUpdate()
const handleRetryConfigChange = useCallback((value?: WorkflowRetryConfig) => {
handleNodeDataUpdateWithSyncDraft({
id,
data: {
retry_config: value,
},
})
}, [id, handleNodeDataUpdateWithSyncDraft])
return {
handleRetryConfigChange,
}
}
export const useRetryDetailShowInSingleRun = () => {
const [retryDetails, setRetryDetails] = useState<NodeTracing[] | undefined>()
const handleRetryDetailsChange = useCallback((details: NodeTracing[] | undefined) => {
setRetryDetails(details)
}, [])
return {
retryDetails,
handleRetryDetailsChange,
}
}

View File

@ -0,0 +1,88 @@
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiAlertFill,
RiCheckboxCircleFill,
RiLoader2Line,
} from '@remixicon/react'
import type { Node } from '@/app/components/workflow/types'
import { NodeRunningStatus } from '@/app/components/workflow/types'
import cn from '@/utils/classnames'
type RetryOnNodeProps = Pick<Node, 'id' | 'data'>
const RetryOnNode = ({
data,
}: RetryOnNodeProps) => {
const { t } = useTranslation()
const { retry_config } = data
const showSelectedBorder = data.selected || data._isBundled || data._isEntering
const {
isRunning,
isSuccessful,
isException,
isFailed,
} = useMemo(() => {
return {
isRunning: data._runningStatus === NodeRunningStatus.Running && !showSelectedBorder,
isSuccessful: data._runningStatus === NodeRunningStatus.Succeeded && !showSelectedBorder,
isFailed: data._runningStatus === NodeRunningStatus.Failed && !showSelectedBorder,
isException: data._runningStatus === NodeRunningStatus.Exception && !showSelectedBorder,
}
}, [data._runningStatus, showSelectedBorder])
const showDefault = !isRunning && !isSuccessful && !isException && !isFailed
if (!retry_config)
return null
return (
<div className='px-3'>
<div className={cn(
'flex items-center justify-between px-[5px] py-1 bg-workflow-block-parma-bg border-[0.5px] border-transparent rounded-md system-xs-medium-uppercase text-text-tertiary',
isRunning && 'bg-state-accent-hover border-state-accent-active text-text-accent',
isSuccessful && 'bg-state-success-hover border-state-success-active text-text-success',
(isException || isFailed) && 'bg-state-warning-hover border-state-warning-active text-text-warning',
)}>
<div className='flex items-center'>
{
showDefault && (
t('workflow.nodes.common.retry.retryTimes', { times: retry_config.max_retries })
)
}
{
isRunning && (
<>
<RiLoader2Line className='animate-spin mr-1 w-3.5 h-3.5' />
{t('workflow.nodes.common.retry.retrying')}
</>
)
}
{
isSuccessful && (
<>
<RiCheckboxCircleFill className='mr-1 w-3.5 h-3.5' />
{t('workflow.nodes.common.retry.retrySuccessful')}
</>
)
}
{
(isFailed || isException) && (
<>
<RiAlertFill className='mr-1 w-3.5 h-3.5' />
{t('workflow.nodes.common.retry.retryFailed')}
</>
)
}
</div>
{
!showDefault && (
<div>
{data._retryIndex}/{data.retry_config?.max_retries}
</div>
)
}
</div>
</div>
)
}
export default RetryOnNode

View File

@ -0,0 +1,117 @@
import { useTranslation } from 'react-i18next'
import { useRetryConfig } from './hooks'
import s from './style.module.css'
import Switch from '@/app/components/base/switch'
import Slider from '@/app/components/base/slider'
import Input from '@/app/components/base/input'
import type {
Node,
} from '@/app/components/workflow/types'
import Split from '@/app/components/workflow/nodes/_base/components/split'
type RetryOnPanelProps = Pick<Node, 'id' | 'data'>
const RetryOnPanel = ({
id,
data,
}: RetryOnPanelProps) => {
const { t } = useTranslation()
const { handleRetryConfigChange } = useRetryConfig(id)
const { retry_config } = data
const handleRetryEnabledChange = (value: boolean) => {
handleRetryConfigChange({
retry_enabled: value,
max_retries: retry_config?.max_retries || 3,
retry_interval: retry_config?.retry_interval || 1000,
})
}
const handleMaxRetriesChange = (value: number) => {
if (value > 10)
value = 10
else if (value < 1)
value = 1
handleRetryConfigChange({
retry_enabled: true,
max_retries: value,
retry_interval: retry_config?.retry_interval || 1000,
})
}
const handleRetryIntervalChange = (value: number) => {
if (value > 5000)
value = 5000
else if (value < 100)
value = 100
handleRetryConfigChange({
retry_enabled: true,
max_retries: retry_config?.max_retries || 3,
retry_interval: value,
})
}
return (
<>
<div className='pt-2'>
<div className='flex items-center justify-between px-4 py-2 h-10'>
<div className='flex items-center'>
<div className='mr-0.5 system-sm-semibold-uppercase text-text-secondary'>{t('workflow.nodes.common.retry.retryOnFailure')}</div>
</div>
<Switch
defaultValue={retry_config?.retry_enabled}
onChange={v => handleRetryEnabledChange(v)}
/>
</div>
{
retry_config?.retry_enabled && (
<div className='px-4 pb-2'>
<div className='flex items-center mb-1 w-full'>
<div className='grow mr-2 system-xs-medium-uppercase'>{t('workflow.nodes.common.retry.maxRetries')}</div>
<Slider
className='mr-3 w-[108px]'
value={retry_config?.max_retries || 3}
onChange={handleMaxRetriesChange}
min={1}
max={10}
/>
<Input
type='number'
wrapperClassName='w-[80px]'
value={retry_config?.max_retries || 3}
onChange={e => handleMaxRetriesChange(e.target.value as any)}
min={1}
max={10}
unit={t('workflow.nodes.common.retry.times') || ''}
className={s.input}
/>
</div>
<div className='flex items-center'>
<div className='grow mr-2 system-xs-medium-uppercase'>{t('workflow.nodes.common.retry.retryInterval')}</div>
<Slider
className='mr-3 w-[108px]'
value={retry_config?.retry_interval || 1000}
onChange={handleRetryIntervalChange}
min={100}
max={5000}
/>
<Input
type='number'
wrapperClassName='w-[80px]'
value={retry_config?.retry_interval || 1000}
onChange={e => handleRetryIntervalChange(e.target.value as any)}
min={100}
max={5000}
unit={t('workflow.nodes.common.retry.ms') || ''}
className={s.input}
/>
</div>
</div>
)
}
</div>
<Split className='mx-4 mt-2' />
</>
)
}
export default RetryOnPanel

View File

@ -0,0 +1,5 @@
.input::-webkit-inner-spin-button,
.input::-webkit-outer-spin-button {
-webkit-appearance: none;
margin: 0;
}

View File

@ -0,0 +1,5 @@
export type WorkflowRetryConfig = {
max_retries: number
retry_interval: number
retry_enabled: boolean
}

View File

@ -25,7 +25,10 @@ import {
useNodesReadOnly,
useToolIcon,
} from '../../hooks'
import { hasErrorHandleNode } from '../../utils'
import {
hasErrorHandleNode,
hasRetryNode,
} from '../../utils'
import { useNodeIterationInteractions } from '../iteration/use-interactions'
import type { IterationNodeType } from '../iteration/types'
import {
@ -35,6 +38,7 @@ import {
import NodeResizer from './components/node-resizer'
import NodeControl from './components/node-control'
import ErrorHandleOnNode from './components/error-handle/error-handle-on-node'
import RetryOnNode from './components/retry/retry-on-node'
import AddVariablePopupWithPosition from './components/add-variable-popup-with-position'
import cn from '@/utils/classnames'
import BlockIcon from '@/app/components/workflow/block-icon'
@ -237,6 +241,14 @@ const BaseNode: FC<BaseNodeProps> = ({
</div>
)
}
{
hasRetryNode(data.type) && (
<RetryOnNode
id={id}
data={data}
/>
)
}
{
hasErrorHandleNode(data.type) && (
<ErrorHandleOnNode

View File

@ -21,9 +21,11 @@ import {
TitleInput,
} from './components/title-description-input'
import ErrorHandleOnPanel from './components/error-handle/error-handle-on-panel'
import RetryOnPanel from './components/retry/retry-on-panel'
import { useResizePanel } from './hooks/use-resize-panel'
import cn from '@/utils/classnames'
import BlockIcon from '@/app/components/workflow/block-icon'
import Split from '@/app/components/workflow/nodes/_base/components/split'
import {
WorkflowHistoryEvent,
useAvailableBlocks,
@ -38,6 +40,7 @@ import {
import {
canRunBySingle,
hasErrorHandleNode,
hasRetryNode,
} from '@/app/components/workflow/utils'
import Tooltip from '@/app/components/base/tooltip'
import type { Node } from '@/app/components/workflow/types'
@ -168,6 +171,15 @@ const BasePanel: FC<BasePanelProps> = ({
<div>
{cloneElement(children, { id, data })}
</div>
<Split />
{
hasRetryNode(data.type) && (
<RetryOnPanel
id={id}
data={data}
/>
)
}
{
hasErrorHandleNode(data.type) && (
<ErrorHandleOnPanel

View File

@ -2,7 +2,10 @@ import { BlockEnum } from '../../types'
import type { NodeDefault } from '../../types'
import { AuthorizationType, BodyType, Method } from './types'
import type { BodyPayload, HttpNodeType } from './types'
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
import {
ALL_CHAT_AVAILABLE_BLOCKS,
ALL_COMPLETION_AVAILABLE_BLOCKS,
} from '@/app/components/workflow/constants'
const nodeDefault: NodeDefault<HttpNodeType> = {
defaultValue: {
@ -24,6 +27,11 @@ const nodeDefault: NodeDefault<HttpNodeType> = {
max_read_timeout: 0,
max_write_timeout: 0,
},
retry_config: {
retry_enabled: true,
max_retries: 3,
retry_interval: 100,
},
},
getAvailablePrevNodes(isChatMode: boolean) {
const nodes = isChatMode

View File

@ -1,5 +1,5 @@
import type { FC } from 'react'
import React from 'react'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import useConfig from './use-config'
import ApiInput from './components/api-input'
@ -18,6 +18,7 @@ import { FileArrow01 } from '@/app/components/base/icons/src/vender/line/files'
import type { NodePanelProps } from '@/app/components/workflow/types'
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
import ResultPanel from '@/app/components/workflow/run/result-panel'
import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks'
const i18nPrefix = 'workflow.nodes.http'
@ -60,6 +61,10 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
hideCurlPanel,
handleCurlImport,
} = useConfig(id, data)
const {
retryDetails,
handleRetryDetailsChange,
} = useRetryDetailShowInSingleRun()
// To prevent prompt editor in body not update data.
if (!isDataReady)
return null
@ -181,6 +186,7 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
nodeType={inputs.type}
onHide={hideSingleRun}
forms={[
{
@ -192,7 +198,9 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
retryDetails={retryDetails}
onRetryDetailBack={handleRetryDetailsChange}
result={<ResultPanel {...runResult} showSteps={false} onShowRetryDetail={handleRetryDetailsChange} />}
/>
)}
{(isShowCurlPanel && !readOnly) && (
@ -207,4 +215,4 @@ const Panel: FC<NodePanelProps<HttpNodeType>> = ({
)
}
export default React.memo(Panel)
export default memo(Panel)

View File

@ -129,9 +129,6 @@ export const getMultipleRetrievalConfig = (
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
}
if (!rerankModelIsValid)
result.reranking_model = undefined
const setDefaultWeights = () => {
result.weights = {
vector_setting: {
@ -198,7 +195,6 @@ export const getMultipleRetrievalConfig = (
setDefaultWeights()
}
}
if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
result.reranking_mode = RerankingModeEnum.WeightedScore
setDefaultWeights()

View File

@ -19,6 +19,7 @@ import type { Props as FormProps } from '@/app/components/workflow/nodes/_base/c
import ResultPanel from '@/app/components/workflow/run/result-panel'
import Tooltip from '@/app/components/base/tooltip'
import Editor from '@/app/components/workflow/nodes/_base/components/prompt/editor'
import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks'
const i18nPrefix = 'workflow.nodes.llm'
@ -69,6 +70,10 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
runResult,
filterJinjia2InputVar,
} = useConfig(id, data)
const {
retryDetails,
handleRetryDetailsChange,
} = useRetryDetailShowInSingleRun()
const model = inputs.model
@ -282,12 +287,15 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
nodeType={inputs.type}
onHide={hideSingleRun}
forms={singleRunForms}
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
retryDetails={retryDetails}
onRetryDetailBack={handleRetryDetailsChange}
result={<ResultPanel {...runResult} showSteps={false} onShowRetryDetail={handleRetryDetailsChange} />}
/>
)}
</div>

View File

@ -162,7 +162,7 @@ const InputVarList: FC<Props> = ({
readonly={readOnly}
isShowNodeName
nodeId={nodeId}
value={varInput?.type === VarKindType.constant ? (varInput?.value || '') : (varInput?.value || [])}
value={varInput?.type === VarKindType.constant ? (varInput?.value ?? '') : (varInput?.value ?? [])}
onChange={handleNotMixedTypeChange(variable)}
onOpen={handleOpen(index)}
defaultVarKindType={varInput?.type || (isNumber ? VarKindType.constant : VarKindType.variable)}

View File

@ -14,6 +14,8 @@ import Loading from '@/app/components/base/loading'
import BeforeRunForm from '@/app/components/workflow/nodes/_base/components/before-run-form'
import OutputVars, { VarItem } from '@/app/components/workflow/nodes/_base/components/output-vars'
import ResultPanel from '@/app/components/workflow/run/result-panel'
import { useRetryDetailShowInSingleRun } from '@/app/components/workflow/nodes/_base/components/retry/hooks'
import { useToolIcon } from '@/app/components/workflow/hooks'
const i18nPrefix = 'workflow.nodes.tool'
@ -48,6 +50,11 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
handleStop,
runResult,
} = useConfig(id, data)
const toolIcon = useToolIcon(data)
const {
retryDetails,
handleRetryDetailsChange,
} = useRetryDetailShowInSingleRun()
if (isLoading) {
return <div className='flex h-[200px] items-center justify-center'>
@ -143,12 +150,16 @@ const Panel: FC<NodePanelProps<ToolNodeType>> = ({
{isShowSingleRun && (
<BeforeRunForm
nodeName={inputs.title}
nodeType={inputs.type}
toolIcon={toolIcon}
onHide={hideSingleRun}
forms={singleRunForms}
runningStatus={runningStatus}
onRun={handleRun}
onStop={handleStop}
result={<ResultPanel {...runResult} showSteps={false} />}
retryDetails={retryDetails}
onRetryDetailBack={handleRetryDetailsChange}
result={<ResultPanel {...runResult} showSteps={false} onShowRetryDetail={handleRetryDetailsChange} />}
/>
)}
</div>

View File

@ -27,6 +27,7 @@ import {
getProcessedFilesFromResponse,
} from '@/app/components/base/file-uploader/utils'
import type { FileEntity } from '@/app/components/base/file-uploader/types'
import type { NodeTracing } from '@/types/workflow'
type GetAbortController = (abortController: AbortController) => void
type SendCallback = {
@ -381,6 +382,28 @@ export const useChat = (
}
}))
},
onNodeRetry: ({ data }) => {
if (data.iteration_id)
return
const currentIndex = responseItem.workflowProcess!.tracing!.findIndex((item) => {
if (!item.execution_metadata?.parallel_id)
return item.node_id === data.node_id
return item.node_id === data.node_id && (item.execution_metadata?.parallel_id === data.execution_metadata?.parallel_id || item.parallel_id === data.execution_metadata?.parallel_id)
})
if (responseItem.workflowProcess!.tracing[currentIndex].retryDetail)
responseItem.workflowProcess!.tracing[currentIndex].retryDetail?.push(data as NodeTracing)
else
responseItem.workflowProcess!.tracing[currentIndex].retryDetail = [data as NodeTracing]
handleUpdateChatList(produce(chatListRef.current, (draft) => {
const currentIndex = draft.findIndex(item => item.id === responseItem.id)
draft[currentIndex] = {
...draft[currentIndex],
...responseItem,
}
}))
},
onNodeFinished: ({ data }) => {
if (data.iteration_id)
return
@ -394,6 +417,9 @@ export const useChat = (
...(responseItem.workflowProcess!.tracing[currentIndex]?.extras
? { extras: responseItem.workflowProcess!.tracing[currentIndex].extras }
: {}),
...(responseItem.workflowProcess!.tracing[currentIndex]?.retryDetail
? { retryDetail: responseItem.workflowProcess!.tracing[currentIndex].retryDetail }
: {}),
...data,
} as any
handleUpdateChatList(produce(chatListRef.current, (draft) => {

View File

@ -25,6 +25,7 @@ import {
import { SimpleBtn } from '../../app/text-generate/item'
import Toast from '../../base/toast'
import IterationResultPanel from '../run/iteration-result-panel'
import RetryResultPanel from '../run/retry-result-panel'
import InputsPanel from './inputs-panel'
import cn from '@/utils/classnames'
import Loading from '@/app/components/base/loading'
@ -53,11 +54,16 @@ const WorkflowPreview = () => {
}, [workflowRunningData])
const [iterationRunResult, setIterationRunResult] = useState<NodeTracing[][]>([])
const [retryRunResult, setRetryRunResult] = useState<NodeTracing[]>([])
const [iterDurationMap, setIterDurationMap] = useState<IterationDurationMap>({})
const [isShowIterationDetail, {
setTrue: doShowIterationDetail,
setFalse: doHideIterationDetail,
}] = useBoolean(false)
const [isShowRetryDetail, {
setTrue: doShowRetryDetail,
setFalse: doHideRetryDetail,
}] = useBoolean(false)
const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterationDurationMap: IterationDurationMap) => {
setIterDurationMap(iterationDurationMap)
@ -65,6 +71,11 @@ const WorkflowPreview = () => {
doShowIterationDetail()
}, [doShowIterationDetail])
const handleRetryDetail = useCallback((detail: NodeTracing[]) => {
setRetryRunResult(detail)
doShowRetryDetail()
}, [doShowRetryDetail])
if (isShowIterationDetail) {
return (
<div className={`
@ -201,11 +212,12 @@ const WorkflowPreview = () => {
<Loading />
</div>
)}
{currentTab === 'TRACING' && (
{currentTab === 'TRACING' && !isShowRetryDetail && (
<TracingPanel
className='bg-background-section-burn'
list={workflowRunningData?.tracing || []}
onShowIterationDetail={handleShowIterationDetail}
onShowRetryDetail={handleRetryDetail}
/>
)}
{currentTab === 'TRACING' && !workflowRunningData?.tracing?.length && (
@ -213,7 +225,14 @@ const WorkflowPreview = () => {
<Loading />
</div>
)}
{
currentTab === 'TRACING' && isShowRetryDetail && (
<RetryResultPanel
list={retryRunResult}
onBack={doHideRetryDetail}
/>
)
}
</div>
</>
)}

View File

@ -9,6 +9,7 @@ import OutputPanel from './output-panel'
import ResultPanel from './result-panel'
import TracingPanel from './tracing-panel'
import IterationResultPanel from './iteration-result-panel'
import RetryResultPanel from './retry-result-panel'
import cn from '@/utils/classnames'
import { ToastContext } from '@/app/components/base/toast'
import Loading from '@/app/components/base/loading'
@ -107,6 +108,18 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
const processNonIterationNode = (item: NodeTracing) => {
const { execution_metadata } = item
if (!execution_metadata?.iteration_id) {
if (item.status === 'retry') {
const retryNode = result.find(node => node.node_id === item.node_id)
if (retryNode) {
if (retryNode?.retryDetail)
retryNode.retryDetail.push(item)
else
retryNode.retryDetail = [item]
}
return
}
result.push(item)
return
}
@ -181,10 +194,15 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
const [iterationRunResult, setIterationRunResult] = useState<NodeTracing[][]>([])
const [iterDurationMap, setIterDurationMap] = useState<IterationDurationMap>({})
const [retryRunResult, setRetryRunResult] = useState<NodeTracing[]>([])
const [isShowIterationDetail, {
setTrue: doShowIterationDetail,
setFalse: doHideIterationDetail,
}] = useBoolean(false)
const [isShowRetryDetail, {
setTrue: doShowRetryDetail,
setFalse: doHideRetryDetail,
}] = useBoolean(false)
const handleShowIterationDetail = useCallback((detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => {
setIterationRunResult(detail)
@ -192,6 +210,11 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
setIterDurationMap(iterDurationMap)
}, [doShowIterationDetail, setIterationRunResult, setIterDurationMap])
const handleShowRetryDetail = useCallback((detail: NodeTracing[]) => {
setRetryRunResult(detail)
doShowRetryDetail()
}, [doShowRetryDetail, setRetryRunResult])
if (isShowIterationDetail) {
return (
<div className='grow relative flex flex-col'>
@ -261,13 +284,22 @@ const RunPanel: FC<RunProps> = ({ hideResult, activeTab = 'RESULT', runID, getRe
exceptionCounts={runDetail.exceptions_count}
/>
)}
{!loading && currentTab === 'TRACING' && (
{!loading && currentTab === 'TRACING' && !isShowRetryDetail && (
<TracingPanel
className='bg-background-section-burn'
list={list}
onShowIterationDetail={handleShowIterationDetail}
onShowRetryDetail={handleShowRetryDetail}
/>
)}
{
!loading && currentTab === 'TRACING' && isShowRetryDetail && (
<RetryResultPanel
list={retryRunResult}
onBack={doHideRetryDetail}
/>
)
}
</div>
</div>
)

View File

@ -8,6 +8,7 @@ import {
RiCheckboxCircleFill,
RiErrorWarningLine,
RiLoader2Line,
RiRestartFill,
} from '@remixicon/react'
import BlockIcon from '../block-icon'
import { BlockEnum } from '../types'
@ -20,6 +21,7 @@ import Button from '@/app/components/base/button'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import type { IterationDurationMap, NodeTracing } from '@/types/workflow'
import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip'
import { hasRetryNode } from '@/app/components/workflow/utils'
type Props = {
className?: string
@ -28,8 +30,10 @@ type Props = {
hideInfo?: boolean
hideProcessDetail?: boolean
onShowIterationDetail?: (detail: NodeTracing[][], iterDurationMap: IterationDurationMap) => void
onShowRetryDetail?: (detail: NodeTracing[]) => void
notShowIterationNav?: boolean
justShowIterationNavArrow?: boolean
justShowRetryNavArrow?: boolean
}
const NodePanel: FC<Props> = ({
@ -39,6 +43,7 @@ const NodePanel: FC<Props> = ({
hideInfo = false,
hideProcessDetail,
onShowIterationDetail,
onShowRetryDetail,
notShowIterationNav,
justShowIterationNavArrow,
}) => {
@ -88,11 +93,17 @@ const NodePanel: FC<Props> = ({
}, [nodeInfo.expand, setCollapseState])
const isIterationNode = nodeInfo.node_type === BlockEnum.Iteration
const isRetryNode = hasRetryNode(nodeInfo.node_type) && nodeInfo.retryDetail
const handleOnShowIterationDetail = (e: React.MouseEvent<HTMLButtonElement>) => {
e.stopPropagation()
e.nativeEvent.stopImmediatePropagation()
onShowIterationDetail?.(nodeInfo.details || [], nodeInfo?.iterDurationMap || nodeInfo.execution_metadata?.iteration_duration_map || {})
}
const handleOnShowRetryDetail = (e: React.MouseEvent<HTMLButtonElement>) => {
e.stopPropagation()
e.nativeEvent.stopImmediatePropagation()
onShowRetryDetail?.(nodeInfo.retryDetail || [])
}
return (
<div className={cn('px-2 py-1', className)}>
<div className='group transition-all bg-background-default border border-components-panel-border rounded-[10px] shadow-xs hover:shadow-md'>
@ -169,6 +180,19 @@ const NodePanel: FC<Props> = ({
<Split className='mt-2' />
</div>
)}
{isRetryNode && (
<Button
className='flex items-center justify-between mb-1 w-full'
variant='tertiary'
onClick={handleOnShowRetryDetail}
>
<div className='flex items-center'>
<RiRestartFill className='mr-0.5 w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
{t('workflow.nodes.common.retry.retries', { num: nodeInfo.retryDetail?.length })}
</div>
<RiArrowRightSLine className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
</Button>
)}
<div className={cn('mb-1', hideInfo && '!px-2 !py-0.5')}>
{(nodeInfo.status === 'stopped') && (
<StatusContainer status='stopped'>

View File

@ -1,11 +1,17 @@
'use client'
import type { FC } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiArrowRightSLine,
RiRestartFill,
} from '@remixicon/react'
import StatusPanel from './status'
import MetaData from './meta'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import ErrorHandleTip from '@/app/components/workflow/nodes/_base/components/error-handle/error-handle-tip'
import type { NodeTracing } from '@/types/workflow'
import Button from '@/app/components/base/button'
type ResultPanelProps = {
inputs?: string
@ -22,6 +28,8 @@ type ResultPanelProps = {
showSteps?: boolean
exceptionCounts?: number
execution_metadata?: any
retry_events?: NodeTracing[]
onShowRetryDetail?: (retries: NodeTracing[]) => void
}
const ResultPanel: FC<ResultPanelProps> = ({
@ -38,8 +46,11 @@ const ResultPanel: FC<ResultPanelProps> = ({
showSteps,
exceptionCounts,
execution_metadata,
retry_events,
onShowRetryDetail,
}) => {
const { t } = useTranslation()
return (
<div className='bg-components-panel-bg py-2'>
<div className='px-4 py-2'>
@ -51,6 +62,23 @@ const ResultPanel: FC<ResultPanelProps> = ({
exceptionCounts={exceptionCounts}
/>
</div>
{
retry_events?.length && onShowRetryDetail && (
<div className='px-4'>
<Button
className='flex items-center justify-between w-full'
variant='tertiary'
onClick={() => onShowRetryDetail(retry_events)}
>
<div className='flex items-center'>
<RiRestartFill className='mr-0.5 w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
{t('workflow.nodes.common.retry.retries', { num: retry_events?.length })}
</div>
<RiArrowRightSLine className='w-4 h-4 text-components-button-tertiary-text flex-shrink-0' />
</Button>
</div>
)
}
<div className='px-4 py-2 flex flex-col gap-2'>
<CodeEditor
readOnly

View File

@ -0,0 +1,46 @@
'use client'
import type { FC } from 'react'
import { memo } from 'react'
import { useTranslation } from 'react-i18next'
import {
RiArrowLeftLine,
} from '@remixicon/react'
import TracingPanel from './tracing-panel'
import type { NodeTracing } from '@/types/workflow'
type Props = {
list: NodeTracing[]
onBack: () => void
}
const RetryResultPanel: FC<Props> = ({
list,
onBack,
}) => {
const { t } = useTranslation()
return (
<div>
<div
className='flex items-center px-4 h-8 text-text-accent-secondary bg-components-panel-bg system-sm-medium cursor-pointer'
onClick={(e) => {
e.stopPropagation()
e.nativeEvent.stopImmediatePropagation()
onBack()
}}
>
<RiArrowLeftLine className='mr-1 w-4 h-4' />
{t('workflow.singleRun.back')}
</div>
<TracingPanel
list={list.map((item, index) => ({
...item,
title: `${t('workflow.nodes.common.retry.retry')} ${index + 1}`,
}))}
className='bg-background-section-burn'
/>
</div >
)
}
export default memo(RetryResultPanel)

Some files were not shown because too many files have changed in this diff Show More