diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index b6c9131c08..9c79dbc57e 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -89,7 +89,9 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
- run: pnpm run lint
+ run: |
+ pnpm run lint
+ pnpm run eslint
docker-compose-template:
name: Docker Compose Template
diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py
index 19ca464a79..f730cfa3fe 100644
--- a/api/controllers/console/auth/oauth_server.py
+++ b/api/controllers/console/auth/oauth_server.py
@@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
- if not request.headers.get("Authorization"):
- raise BadRequest("Authorization is required")
-
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
- parts = authorization_header.split(" ")
+ parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
- token_type = parts[0]
- if token_type != "Bearer":
+ token_type = parts[0].strip()
+ if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
- access_token = parts[1]
+ access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
@@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
- grant_type = OAuthGrantType(parsed_args["grant_type"])
+ try:
+ grant_type = OAuthGrantType(parsed_args["grant_type"])
+ except ValueError:
+ raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
@@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
"refresh_token": refresh_token,
}
)
- else:
- raise BadRequest("invalid grant_type")
class OAuthServerUserAccountApi(Resource):
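
For context, a minimal standalone sketch of the Bearer-token parsing this hunk adopts: trim the header, accept the `Bearer` keyword case-insensitively, and reject anything that is not exactly two space-separated parts. The helper name is hypothetical and not part of the PR; the real code raises `BadRequest` instead of `ValueError`.

```python
# Hypothetical helper mirroring the parsing logic in the hunk above.
def parse_bearer_token(authorization_header: str) -> str:
    """Return the access token from an Authorization header, or raise ValueError."""
    if not authorization_header:
        raise ValueError("Authorization header is required")
    parts = authorization_header.strip().split(" ")
    if len(parts) != 2:
        raise ValueError("Invalid Authorization header format")
    token_type, access_token = (p.strip() for p in parts)
    if token_type.lower() != "bearer":
        raise ValueError("token_type is invalid")
    if not access_token:
        raise ValueError("access_token is required")
    return access_token


# e.g. parse_bearer_token("bearer abc123") -> "abc123"
```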
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 8d50b0d41c..22bb81f9e3 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -354,9 +354,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
- # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
- if not current_user.is_dataset_editor:
- raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py
index c5aa318f58..de4f1da801 100644
--- a/api/controllers/inner_api/wraps.py
+++ b/api/controllers/inner_api/wraps.py
@@ -1,8 +1,12 @@
from base64 import b64encode
+from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
+from typing import ParamSpec, TypeVar
+P = ParamSpec("P")
+R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
-def billing_inner_api_only(view):
+def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
- def decorated(*args, **kwargs):
+ def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
-def enterprise_inner_api_only(view):
+def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
- def decorated(*args, **kwargs):
+ def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
-def plugin_inner_api_only(view):
+def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
- def decorated(*args, **kwargs):
+ def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)
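
The `ParamSpec`/`TypeVar` annotations above let type checkers see through the decorator and keep the wrapped view's signature. A generic sketch of the same pattern follows; the decorator name and guard are illustrative (not from the codebase), and the sketch also annotates the return types, which the hunks above leave implicit.

```python
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def require_feature_flag(view: Callable[P, R]) -> Callable[P, R]:
    """Illustrative decorator: type checkers still see the view's original signature."""

    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        # A real guard (e.g. aborting when a config flag is off) would go here.
        return view(*args, **kwargs)

    return decorated


@require_feature_flag
def add(a: int, b: int) -> int:
    return a + b


assert add(1, 2) == 3  # add() keeps its (int, int) -> int signature for type checkers
```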
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index d2473c15af..cc4b5f65bd 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
-from enum import Enum
+from enum import StrEnum, auto
from functools import wraps
from typing import Optional
@@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
-class WhereisUserArg(Enum):
+class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
- QUERY = "query"
- JSON = "json"
- FORM = "form"
+ QUERY = auto()
+ JSON = auto()
+ FORM = auto()
class FetchUserArg(BaseModel):
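
With `StrEnum` (Python 3.11+), `auto()` assigns the lowercased member name as the value, so the members keep their previous string values ("query", "json", "form"). A quick sketch of the behavior being relied on:

```python
from enum import StrEnum, auto


class WhereisUserArgDemo(StrEnum):
    QUERY = auto()
    JSON = auto()
    FORM = auto()


assert WhereisUserArgDemo.QUERY == "query"  # auto() under StrEnum -> lowercased name
assert WhereisUserArgDemo.JSON.value == "json"
```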
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index f7c83f927f..f5e45bcb47 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
- agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
+ stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
+ agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")
@@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
- files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
+ stmt = select(MessageFile).where(MessageFile.message_id == message.id)
+ files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:
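
The recurring change throughout this PR replaces legacy `db.session.query(...).first()` / `.all()` calls with SQLAlchemy 2.0-style `select()` statements executed via `Session.scalar()` and `Session.scalars()`. A self-contained sketch of the pattern, using an in-memory SQLite database and an illustrative model rather than the real Dify models:

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class MessageFileDemo(Base):
    __tablename__ = "message_files"
    id: Mapped[int] = mapped_column(primary_key=True)
    message_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([MessageFileDemo(message_id="m1"), MessageFileDemo(message_id="m1")])
    session.commit()

    stmt = select(MessageFileDemo).where(MessageFileDemo.message_id == "m1")

    # scalar(): first matching ORM object or None, like the old .first()
    first_file = session.scalar(stmt)

    # scalars().all(): list of ORM objects, like the old .all()
    all_files = session.scalars(stmt).all()
    assert first_file is not None and len(all_files) == 2
```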
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 52ae20ee16..74e282fdcd 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
+ # refresh the ORM objects and release the database connection; the worker thread started above may run for a long time
+ db.session.refresh(workflow)
+ db.session.refresh(message)
+ db.session.refresh(user)
+ db.session.close()
+
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 452dbbec01..2a8efad15e 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
- app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+ with Session(db.engine, expire_on_commit=False) as session:
+ app_record = session.scalar(select(App).where(App.id == app_config.app_id))
+
if not app_record:
raise ValueError("App not found")
@@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
- conversation_variables=cast(list[VariableUnion], conversation_variables),
+ conversation_variables=conversation_variables,
)
# init graph
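
Several hunks in this PR read through a short-lived `Session(db.engine, expire_on_commit=False)` so the returned ORM objects stay usable after the session is gone. A minimal sketch of that assumption, with an illustrative model in place of `App`:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class AppDemo(Base):
    __tablename__ = "apps"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite+pysqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as seed:
    seed.add(AppDemo(name="demo"))
    seed.commit()

with Session(engine, expire_on_commit=False) as session:
    app_record = session.scalar(select(AppDemo).where(AppDemo.name == "demo"))
    session.commit()  # with the default expire_on_commit=True this would expire the loaded attributes

# The detached object stays readable after the session block exits.
assert app_record is not None and app_record.name == "demo"
```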
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index b08bd9c872..10e51ff338 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
-from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@@ -886,10 +885,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
- message_was_created.send(
- message,
- application_generate_entity=self._application_generate_entity,
- )
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 39d6ba39f5..d3207365f3 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -1,6 +1,8 @@
import logging
from typing import cast
+from sqlalchemy import select
+
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
-
- app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+ app_stmt = select(App).where(App.id == app_config.app_id)
+ app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
-
- conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
+ conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
+ conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
- message_result = db.session.query(Message).where(Message.id == message.id).first()
+ msg_stmt = select(Message).where(Message.id == message.id)
+ message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 2cffe4a0a5..feacca1a07 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
-from enum import Enum
+from enum import IntEnum, auto
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client
-class PublishFrom(Enum):
- APPLICATION_MANAGER = 1
- TASK_PIPELINE = 2
+class PublishFrom(IntEnum):
+ APPLICATION_MANAGER = auto()
+ TASK_PIPELINE = auto()
class AppQueueManager:
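
Switching `PublishFrom` to `IntEnum` with `auto()` keeps the original integer values, since `auto()` starts at 1 and increments, so `APPLICATION_MANAGER` stays 1 and `TASK_PIPELINE` stays 2. A quick check of that assumption:

```python
from enum import IntEnum, auto


class PublishFromDemo(IntEnum):
    APPLICATION_MANAGER = auto()  # 1
    TASK_PIPELINE = auto()        # 2


assert PublishFromDemo.APPLICATION_MANAGER == 1
assert PublishFromDemo.TASK_PIPELINE == 2
```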
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 894d7906d5..4385d0f08d 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -1,6 +1,8 @@
import logging
from typing import cast
+from sqlalchemy import select
+
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
-
- app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+ stmt = select(App).where(App.id == app_config.app_id)
+ app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 64dade2968..8d2f3d488b 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
+from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
- message = (
- db.session.query(Message)
- .where(
- Message.id == message_id,
- Message.app_id == app_model.id,
- Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
- Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
- Message.from_account_id == (user.id if isinstance(user, Account) else None),
- )
- .first()
+ stmt = select(Message).where(
+ Message.id == message_id,
+ Message.app_id == app_model.id,
+ Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
+ Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+ Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
+ message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 50d2a0036c..d384bff255 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -1,6 +1,8 @@
import logging
from typing import cast
+from sqlalchemy import select
+
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
-
- app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+ stmt = select(App).where(App.id == app_config.app_id)
+ app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index 11c979765b..92f3b6507c 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -3,6 +3,9 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
- app_model_config = (
- db.session.query(AppModelConfig)
- .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
- .first()
+ stmt = select(AppModelConfig).where(
+ AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
+ app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
- conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+ with Session(db.engine, expire_on_commit=False) as session:
+ conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
- message = db.session.query(Message).where(Message.id == message_id).first()
+ with Session(db.engine, expire_on_commit=False) as session:
+ message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")
diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py
index b829340401..be183e2086 100644
--- a/api/core/app/features/annotation_reply/annotation_reply.py
+++ b/api/core/app/features/annotation_reply/annotation_reply.py
@@ -1,6 +1,8 @@
import logging
from typing import Optional
+from sqlalchemy import select
+
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
- annotation_setting = (
- db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
- )
+ stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
+ annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index 471118c8cb..e3b917067f 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
- agent_thought: Optional[MessageAgentThought] = (
- db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
- )
+ with Session(db.engine, expire_on_commit=False) as session:
+ agent_thought: Optional[MessageAgentThought] = (
+ session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
+ )
if agent_thought:
return AgentThoughtStreamResponse(
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py
index 50b51f70fe..5c19eda21e 100644
--- a/api/core/app/task_pipeline/message_cycle_manager.py
+++ b/api/core/app/task_pipeline/message_cycle_manager.py
@@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
+from sqlalchemy import select
+from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
- conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+ stmt = select(Conversation).where(Conversation.id == conversation_id)
+ conversation = db.session.scalar(stmt)
if not conversation:
return
@@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
- message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
+ with Session(db.engine, expire_on_commit=False) as session:
+ message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
- message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
+ with Session(db.engine, expire_on_commit=False) as session:
+ message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index c55ba5e0fe..5cf39d7611 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence
+from sqlalchemy import select
+
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
- dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+ dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
+ dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
- child_chunk = (
- db.session.query(ChildChunk)
- .where(
- ChildChunk.index_node_id == document.metadata["doc_id"],
- ChildChunk.dataset_id == dataset_document.dataset_id,
- ChildChunk.document_id == dataset_document.id,
- )
- .first()
+ child_chunk_stmt = select(ChildChunk).where(
+ ChildChunk.index_node_id == document.metadata["doc_id"],
+ ChildChunk.dataset_id == dataset_document.dataset_id,
+ ChildChunk.document_id == dataset_document.id,
)
+ child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)
diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py
index d81f372d40..2100e7fadc 100644
--- a/api/core/external_data_tool/api/api.py
+++ b/api/core/external_data_tool/api/api.py
@@ -1,5 +1,7 @@
from typing import Optional
+from sqlalchemy import select
+
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
@@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
-
# get api_based_extension
- api_based_extension = (
- db.session.query(APIBasedExtension)
- .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
- .first()
+ stmt = select(APIBasedExtension).where(
+ APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
+ api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
-
# get api_based_extension
- api_based_extension = (
- db.session.query(APIBasedExtension)
- .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
- .first()
+ stmt = select(APIBasedExtension).where(
+ APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
+ api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError(
diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py
index c6bb2007d6..17345dc203 100644
--- a/api/core/helper/encrypter.py
+++ b/api/core/helper/encrypter.py
@@ -3,7 +3,7 @@ import base64
from libs import rsa
-def obfuscated_token(token: str):
+def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:
@@ -11,6 +11,10 @@ def obfuscated_token(token: str):
return token[:6] + "*" * 12 + token[-2:]
+def full_mask_token(token_length: int = 20) -> str:
+ return "*" * token_length
+
+
def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py
index fe3078923d..e837f2fd38 100644
--- a/api/core/helper/marketplace.py
+++ b/api/core/helper/marketplace.py
@@ -1,6 +1,6 @@
from collections.abc import Sequence
-import requests
+import httpx
from yarl import URL
from configs import dify_config
@@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
- response = requests.post(url, json={"plugin_ids": plugin_ids})
+ response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
- response = requests.post(url, json={"plugin_ids": plugin_ids})
+ response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
@@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
- response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
+ response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()
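
The marketplace helpers swap `requests.post` for `httpx.post`, whose call signature for a simple JSON POST is the same. One behavioral difference worth noting: recent httpx versions apply a default timeout (5 seconds), whereas requests would wait indefinitely. A hedged sketch with an illustrative URL (the real URLs come from `dify_config.MARKETPLACE_API_URL`):

```python
import httpx


def fetch_plugins_demo(plugin_ids: list[str]) -> list[dict]:
    response = httpx.post(
        "https://marketplace.example.com/api/v1/plugins/batch",  # hypothetical URL
        json={"plugin_ids": plugin_ids},
        timeout=10.0,  # httpx defaults to ~5s; set explicitly for slow endpoints
    )
    response.raise_for_status()
    return response.json()["data"]["plugins"]
```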
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index a8e6c261c2..4a768618f5 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -8,6 +8,7 @@ import uuid
from typing import Any, Optional, cast
from flask import current_app
+from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@@ -56,13 +57,11 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
-
# get the process rule
- processing_rule = (
- db.session.query(DatasetProcessRule)
- .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
- .first()
+ stmt = select(DatasetProcessRule).where(
+ DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
+ processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@@ -123,11 +122,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
- processing_rule = (
- db.session.query(DatasetProcessRule)
- .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
- .first()
- )
+ stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+ processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
@@ -208,7 +204,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
-
# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@@ -310,7 +305,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+ stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
+ image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@@ -339,10 +335,8 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
-
- file_detail = (
- db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
- )
+ stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
+ file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
extract_setting = ExtractSetting(
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 6cbd949be9..cb768e2036 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -110,9 +110,9 @@ class TokenBufferMemory:
else:
message_limit = 500
- stmt = stmt.limit(message_limit)
+ msg_limit_stmt = stmt.limit(message_limit)
- messages = db.session.scalars(stmt).all()
+ messages = db.session.scalars(msg_limit_stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 51af3d1877..e567565548 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
-
- self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
-
- self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
-
- self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
-
- self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
-
- self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
-
- self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
-
- self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
-
- self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
-
- self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)
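
These hunks drop `cast()` calls that had become redundant: the preceding `isinstance()` check already narrows `self.model_type_instance` for type checkers. A minimal illustration of that narrowing, with hypothetical stand-in classes:

```python
from typing import Union


class LargeLanguageModelDemo:
    def invoke(self) -> str:
        return "ok"


class TTSModelDemo:
    def speak(self) -> bytes:
        return b"ok"


def run(instance: Union[LargeLanguageModelDemo, TTSModelDemo]) -> str:
    if not isinstance(instance, LargeLanguageModelDemo):
        raise Exception("Model type instance is not LargeLanguageModel")
    # After the isinstance guard, type checkers already treat `instance`
    # as LargeLanguageModelDemo, so no cast() is needed here.
    return instance.invoke()


assert run(LargeLanguageModelDemo()) == "ok"
```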
diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py
index af51b72cd5..06d5c02bb8 100644
--- a/api/core/moderation/api/api.py
+++ b/api/core/moderation/api/api.py
@@ -1,6 +1,7 @@
from typing import Optional
from pydantic import BaseModel, Field
+from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@@ -87,10 +88,9 @@ class ApiModeration(Moderation):
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
- extension = (
- db.session.query(APIBasedExtension)
- .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
- .first()
+ stmt = select(APIBasedExtension).where(
+ APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
+ extension = db.session.scalar(stmt)
return extension
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index 77852e2a98..7caad89353 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode
+from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
-
- app = session.query(App).where(App.id == app_id).first()
+ app_stmt = select(App).where(App.id == app_id)
+ app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
-
- service_account = session.query(Account).where(Account.id == app.created_by).first()
+ account_stmt = select(Account).where(Account.id == app.created_by)
+ service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py
index f8e428daf1..04b46d67a8 100644
--- a/api/core/ops/base_trace_instance.py
+++ b/api/core/ops/base_trace_instance.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
+from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig
@@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
- app = session.query(App).where(App.id == app_id).first()
+ app_stmt = select(App).where(App.id == app_id)
+ app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
-
- service_account = session.query(Account).where(Account.id == app.created_by).first()
+ account_stmt = select(Account).where(Account.id == app.created_by)
+ service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index c2bc4339d7..d3040f1093 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -228,9 +228,9 @@ class OpsTraceManager:
if not trace_config_data:
return None
-
# decrypt_token
- app = db.session.query(App).where(App.id == app_id).first()
+ stmt = select(App).where(App.id == app_id)
+ app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")
@@ -297,20 +297,19 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
- message_data = db.session.query(Message).where(Message.id == message_id).first()
+ message_stmt = select(Message).where(Message.id == message_id)
+ message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
- conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+ conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
+ conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None
if conversation_data.app_model_config_id:
- app_model_config = (
- db.session.query(AppModelConfig)
- .where(AppModelConfig.id == conversation_data.app_model_config_id)
- .first()
- )
+ config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
+ app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs
diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py
index 74972a2a9c..549b6a8889 100644
--- a/api/core/plugin/backwards_invocation/app.py
+++ b/api/core/plugin/backwards_invocation/app.py
@@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping
from typing import Optional, Union
+from sqlalchemy import select
+
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
get the user by user id
"""
-
- user = db.session.query(EndUser).where(EndUser.id == user_id).first()
+ stmt = select(EndUser).where(EndUser.id == user_id)
+ user = db.session.scalar(stmt)
if not user:
- user = db.session.query(Account).where(Account.id == user_id).first()
+ stmt = select(Account).where(Account.id == user_id)
+ user = db.session.scalar(stmt)
if not user:
raise ValueError("user not found")
diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py
index 2f4e651461..cdc6ccc821 100644
--- a/api/core/prompt/utils/prompt_message_util.py
+++ b/api/core/prompt/utils/prompt_message_util.py
@@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
- content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index d8cc2293a2..e0fb0591e8 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
-from typing import Any, Optional, cast
+from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
- include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
- exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
+ include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
+ exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
@@ -276,15 +276,11 @@ class ProviderManager:
:param model_type: model type
:return:
"""
- # Get the corresponding TenantDefaultModel record
- default_model = (
- db.session.query(TenantDefaultModel)
- .where(
- TenantDefaultModel.tenant_id == tenant_id,
- TenantDefaultModel.model_type == model_type.to_origin_model_type(),
- )
- .first()
+ stmt = select(TenantDefaultModel).where(
+ TenantDefaultModel.tenant_id == tenant_id,
+ TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
+ default_model = db.session.scalar(stmt)
# If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record
@@ -367,16 +363,11 @@ class ProviderManager:
model_names = [model.model for model in available_models]
if model not in model_names:
raise ValueError(f"Model {model} does not exist.")
-
- # Get the list of available models from get_configurations and check if it is LLM
- default_model = (
- db.session.query(TenantDefaultModel)
- .where(
- TenantDefaultModel.tenant_id == tenant_id,
- TenantDefaultModel.model_type == model_type.to_origin_model_type(),
- )
- .first()
+ stmt = select(TenantDefaultModel).where(
+ TenantDefaultModel.tenant_id == tenant_id,
+ TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
+ default_model = db.session.scalar(stmt)
# create or update TenantDefaultModel record
if default_model:
@@ -598,16 +589,13 @@ class ProviderManager:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
- existed_provider_record = (
- db.session.query(Provider)
- .where(
- Provider.tenant_id == tenant_id,
- Provider.provider_name == ModelProviderID(provider_name).provider_name,
- Provider.provider_type == ProviderType.SYSTEM.value,
- Provider.quota_type == ProviderQuotaType.TRIAL.value,
- )
- .first()
+ stmt = select(Provider).where(
+ Provider.tenant_id == tenant_id,
+ Provider.provider_name == ModelProviderID(provider_name).provider_name,
+ Provider.provider_type == ProviderType.SYSTEM.value,
+ Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
+ existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
continue
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index c98306ea4b..5fb6f9fcc8 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson
from pydantic import BaseModel
+from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
- document_segment = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
- .first()
+ stmt = select(DocumentSegment).where(
+ DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
+ document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index e872a4e375..fefd42f84d 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
+from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
@@ -24,7 +25,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
- "top_k": 2,
+ "top_k": 4,
"score_threshold_enabled": False,
}
@@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ stmt = select(Dataset).where(Dataset.id == dataset_id)
+ dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
@@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
-
- child_chunk = (
- db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
- )
+ child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
+ child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue
@@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
-
- segment = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.dataset_id == dataset_document.dataset_id,
- DocumentSegment.enabled == True,
- DocumentSegment.status == "completed",
- DocumentSegment.index_node_id == index_node_id,
- )
- .first()
+ document_segment_stmt = select(DocumentSegment).where(
+ DocumentSegment.dataset_id == dataset_document.dataset_id,
+ DocumentSegment.enabled == True,
+ DocumentSegment.status == "completed",
+ DocumentSegment.index_node_id == index_node_id,
)
+ segment = db.session.scalar(document_segment_stmt)
if not segment:
continue
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
index 6f3e15d166..aa0204ba70 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
@@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
- if match.score > score_threshold:
+ if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
@@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
- if match.score > score_threshold:
+ if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
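
The `>` to `>=` change repeated across the vector stores matters mostly at the default `score_threshold` of 0.0: a strict comparison silently dropped matches whose score exactly equals the threshold (including legitimate 0.0 scores), while `>=` keeps them. A tiny sketch of the filtering difference, with illustrative scores:

```python
matches = [("doc_a", 0.82), ("doc_b", 0.0)]  # (document, similarity score)
score_threshold = 0.0

strict = [doc for doc, score in matches if score > score_threshold]
inclusive = [doc for doc, score in matches if score >= score_threshold]

assert strict == ["doc_a"]              # doc_b dropped by the old comparison
assert inclusive == ["doc_a", "doc_b"]  # kept after the change
```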
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
index df2173d1ca..71472cce41 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
@@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
- if score > score_threshold:
+ if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,
diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
index d63ca9f695..d30cf42601 100644
--- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py
+++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py
@@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
- if score > score_threshold:
+ if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)
diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
index 699a602365..88da86cf76 100644
--- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py
+++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py
@@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
distance = distances[index]
metadata = dict(metadatas[index])
score = 1 - distance
- if score > score_threshold:
+ if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=documents[index],
diff --git a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
index bd986393d1..d22a7e4fd4 100644
--- a/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
+++ b/api/core/rag/datasource/vdb/couchbase/couchbase_vector.py
@@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- top_k = kwargs.get("top_k", 2)
+ top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(
diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index 49c4b392fe..cbad0e67de 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
- if score > score_threshold:
+ if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
index 0a4067e39c..f0d014b1ec 100644
--- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
+++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
@@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
- if score > score_threshold:
+ if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index 3c65a41f08..cba10b5aa5 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
- if score > score_threshold:
+ if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py
index 3ba9569d3f..c448210d94 100644
--- a/api/core/rag/datasource/vdb/opengauss/opengauss.py
+++ b/api/core/rag/datasource/vdb/opengauss/opengauss.py
@@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
- if score > score_threshold:
+ if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
index 3c9302f4da..71d2ba1427 100644
--- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
+++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
@@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
- if hit["_score"] > score_threshold:
+ if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 0956914070..1b99f649bf 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
- if score > score_threshold:
+ if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs
diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
index e77befcdae..99cd4a22cb 100644
--- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
+++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
@@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
score = 1 - dis
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
- if score > score_threshold:
+ if score >= score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
index 108167d749..13be18f920 100644
--- a/api/core/rag/datasource/vdb/pgvector/pgvector.py
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -195,7 +195,7 @@ class PGVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
- if score > score_threshold:
+ if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
index 580da7d62e..c33e344bff 100644
--- a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
+++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
@@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
- if score > score_threshold:
+ if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index fcf3a6d126..e55c06e665 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
from flask import current_app
@@ -18,6 +18,7 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
+from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -369,7 +370,7 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
- if result.score > score_threshold:
+ if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
@@ -426,7 +427,6 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
- self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
@@ -446,11 +446,8 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
- dataset_collection_binding = (
- db.session.query(DatasetCollectionBinding)
- .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
- .one_or_none()
- )
+ stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
+ dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:
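
Besides the threshold fix, this file starts a pattern repeated through the rest of the patch: legacy `db.session.query(...).one_or_none()` / `.first()` calls are rewritten as 2.0-style `select()` statements executed with `Session.scalars()` or `Session.scalar()`. A runnable sketch of the equivalence on a toy model and an in-memory SQLite database (names are illustrative, not Dify models):

```python
# Runnable sketch of the Query-API -> select() migration used throughout this patch.
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Binding(Base):
    __tablename__ = "bindings"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    collection_name: Mapped[str] = mapped_column(String(64))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Binding(id=1, collection_name="demo"))
    session.commit()

    # Legacy 1.x Query API (what the removed lines did):
    legacy = session.query(Binding).where(Binding.id == 1).one_or_none()

    # 2.0-style select() executed via scalars(), as in the added lines:
    stmt = select(Binding).where(Binding.id == 1)
    modern = session.scalars(stmt).one_or_none()

    assert legacy is modern  # identity map: same ORM instance either way
```

Where the old code ended with `.first()`, the later hunks use `db.session.scalar(stmt)`, which likewise returns the first object or `None`.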
diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
index 7a42dd1a89..a200bacfb6 100644
--- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py
+++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py
@@ -233,7 +233,7 @@ class RelytVector(BaseVector):
docs = []
for document, score in results:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
- if 1 - score > score_threshold:
+ if 1 - score >= score_threshold:
docs.append(document)
return docs
diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
index e66959045f..1c154ef360 100644
--- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
+++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
@@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
)
documents = []
for search_hit in search_response.search_hits:
- if search_hit.score > score_threshold:
+ if search_hit.score >= score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index 0517d5a6d1..3df35d081f 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -39,6 +39,9 @@ class TencentConfig(BaseModel):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
+bm25 = BM25Encoder.default("zh")
+
+
class TencentVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
@@ -53,7 +56,6 @@ class TencentVector(BaseVector):
self._dimension = 1024
self._init_database()
self._load_collection()
- self._bm25 = BM25Encoder.default("zh")
def _load_collection(self):
"""
@@ -186,7 +188,7 @@ class TencentVector(BaseVector):
metadata=metadata,
)
if self._enable_hybrid_search:
- doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
+ doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i])
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,
@@ -264,7 +266,7 @@ class TencentVector(BaseVector):
match=[
KeywordSearch(
field_name="sparse_vector",
- data=self._bm25.encode_queries(query),
+ data=bm25.encode_queries(query),
),
],
rerank=WeightedRerank(
@@ -291,7 +293,7 @@ class TencentVector(BaseVector):
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
- if score > score_threshold:
+ if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)
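
Here the `BM25Encoder.default("zh")` instance moves from the constructor to module scope, so one encoder is built per process and shared by all `TencentVector` objects instead of being rebuilt for every instance; the trade-off is that it is now constructed at import time even in processes that never touch the Tencent backend. A generic sketch of the pattern (the encoder class below is a stand-in, not the real SDK type):

```python
# Sketch of the "build once at module scope" pattern; ExpensiveEncoder is a
# stand-in for the real BM25 encoder, whose construction is assumed to be costly.
class ExpensiveEncoder:
    def encode_texts(self, text: str) -> list[int]:
        return [len(token) for token in text.split()]


# Module-level singleton, created when the module is first imported.
_shared_encoder = ExpensiveEncoder()


class VectorStore:
    """Every instance reuses the shared encoder instead of building its own."""

    def sparse_vector(self, text: str) -> list[int]:
        return _shared_encoder.encode_texts(text)


store_a, store_b = VectorStore(), VectorStore()
assert store_a.sparse_vector("hello world") == store_b.sparse_vector("hello world")
```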
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
index a76b5d579c..be24f5a561 100644
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
@@ -20,6 +20,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
+from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -351,7 +352,7 @@ class TidbOnQdrantVector(BaseVector):
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
- if result.score > score_threshold:
+ if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
@@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
- tidb_auth_binding = (
- db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
- )
+ stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
+ tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
- tidb_auth_binding = (
- db.session.query(TidbAuthBinding)
- .where(TidbAuthBinding.tenant_id == dataset.tenant_id)
- .one_or_none()
- )
+ stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
+ tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
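
The factory keeps its check / lock / re-check shape: look up the binding without a lock, and only if it is missing take the `redis_client.lock(...)` and query again before provisioning anything, so concurrent workers do not create duplicate clusters. A minimal sketch of that double-checked pattern, with a `threading.Lock` standing in for the Redis lock and a dict standing in for the table:

```python
# Sketch of the check / lock / re-check pattern preserved by this hunk.
import threading
from typing import Optional

_lock = threading.Lock()
_bindings: dict[str, str] = {}  # tenant_id -> binding (stand-in for the DB table)


def _find(tenant_id: str) -> Optional[str]:
    return _bindings.get(tenant_id)


def get_or_create_binding(tenant_id: str) -> str:
    binding = _find(tenant_id)          # fast path: no lock needed
    if binding is None:
        with _lock:                     # serialize creation across workers
            binding = _find(tenant_id)  # re-check: another worker may have won
            if binding is None:
                binding = f"cluster-for-{tenant_id}"
                _bindings[tenant_id] = binding
    return binding
```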
diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py
index e4f15be2b0..9e99f14dc5 100644
--- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py
+++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py
@@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
score = record.score
if metadata is not None and text is not None:
metadata["score"] = score
- if score > score_threshold:
+ if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index eef03ce412..661a8f37aa 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -3,6 +3,8 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Optional
+from sqlalchemy import select
+
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -45,11 +47,10 @@ class Vector:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
- whitelist = (
- db.session.query(Whitelist)
- .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
- .one_or_none()
+ stmt = select(Whitelist).where(
+ Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
+ whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT
diff --git a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
index 9166d35bc8..a0a2e47d19 100644
--- a/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
+++ b/api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
@@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
- if result.score > score_threshold:
+ if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index 5525ef1685..a7e0789a92 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
- if score > score_threshold:
+ if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index f8da3657fc..717cfe8f53 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any, Optional
-from sqlalchemy import func
+from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@@ -41,9 +41,8 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
- document_segments = (
- db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
- )
+ stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
+ document_segments = db.session.scalars(stmt).all()
output = {}
for document_segment in document_segments:
@@ -228,10 +227,9 @@ class DatasetDocumentStore:
return data
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
- document_segment = (
- db.session.query(DocumentSegment)
- .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
- .first()
+ stmt = select(DocumentSegment).where(
+ DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
)
+ document_segment = db.session.scalar(stmt)
return document_segment
diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py
index c97765b1dc..3845392c8d 100644
--- a/api/core/rag/extractor/markdown_extractor.py
+++ b/api/core/rag/extractor/markdown_extractor.py
@@ -2,7 +2,7 @@
import re
from pathlib import Path
-from typing import Optional, cast
+from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
- (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
+ (re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index 17f4d1af2d..206b2bb921 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -4,6 +4,7 @@ import operator
from typing import Any, Optional, cast
import requests
+from sqlalchemy import select
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
@@ -367,22 +368,17 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
- data_source_binding = (
- db.session.query(DataSourceOauthBinding)
- .where(
- db.and_(
- DataSourceOauthBinding.tenant_id == tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.disabled == False,
- DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
- )
- )
- .first()
+ stmt = select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.disabled == False,
+ DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
+ data_source_binding = db.session.scalar(stmt)
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
- return cast(str, data_source_binding.access_token)
+ return data_source_binding.access_token
diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py
index 7dfe2e357c..3c43f34104 100644
--- a/api/core/rag/extractor/pdf_extractor.py
+++ b/api/core/rag/extractor/pdf_extractor.py
@@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
-from typing import Optional, cast
+from typing import Optional
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
- text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
+ text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index 9b90bd2bb3..997b0b953b 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
- if result.score > score_threshold:
+ if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index 52756fbacd..cb7f6ab57a 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -130,13 +130,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if delete_child_chunks:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
- ).delete()
+ ).delete(synchronize_session=False)
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
- db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
+ # Use existing compound index: (tenant_id, dataset_id, ...)
+ db.session.query(ChildChunk).where(
+ ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
+ ).delete(synchronize_session=False)
db.session.commit()
def retrieve(
@@ -162,7 +165,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
- if result.score > score_threshold:
+ if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
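
Two things change in the delete path: the bulk deletes pass `synchronize_session=False`, which emits a single `DELETE ... WHERE` without reconciling the session's in-memory identity map (the immediate `commit()` that follows expires in-session state in the default configuration anyway), and the dataset-wide delete now also filters on `tenant_id` so the existing compound index can serve the query. A runnable sketch of such a bulk delete on a toy model:

```python
# Sketch of a bulk delete with synchronize_session=False on a toy model.
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Chunk(Base):
    __tablename__ = "chunks"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    dataset_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Chunk(tenant_id="t1", dataset_id="d1"), Chunk(tenant_id="t1", dataset_id="d2")])
    session.commit()

    # One DELETE ... WHERE statement, skipping identity-map synchronization.
    session.query(Chunk).where(
        Chunk.tenant_id == "t1", Chunk.dataset_id == "d1"
    ).delete(synchronize_session=False)
    session.commit()

    assert session.query(Chunk).count() == 1
```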
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index 609a8aafa1..8c345b7edf 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
- if result.score > score_threshold:
+ if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index cd4af72832..11010c9d60 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
-from sqlalchemy import Float, and_, or_, text
+from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
@@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
- "top_k": 2,
+ "top_k": 4,
"score_threshold_enabled": False,
}
@@ -135,7 +135,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
+ dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
@@ -240,15 +241,12 @@ class DatasetRetrieval:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
- document = (
- db.session.query(DatasetDocument)
- .where(
- DatasetDocument.id == segment.document_id,
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- .first()
+ dataset_document_stmt = select(DatasetDocument).where(
+ DatasetDocument.id == segment.document_id,
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
)
+ document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
@@ -327,7 +325,8 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
+ dataset = db.session.scalar(dataset_stmt)
if dataset:
results = []
if dataset.provider == "external":
@@ -514,22 +513,18 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
- dataset_document = (
- db.session.query(DatasetDocument)
- .where(DatasetDocument.id == document.metadata["document_id"])
- .first()
+ dataset_document_stmt = select(DatasetDocument).where(
+ DatasetDocument.id == document.metadata["document_id"]
)
+ dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
- child_chunk = (
- db.session.query(ChildChunk)
- .where(
- ChildChunk.index_node_id == document.metadata["doc_id"],
- ChildChunk.dataset_id == dataset_document.dataset_id,
- ChildChunk.document_id == dataset_document.id,
- )
- .first()
+ child_chunk_stmt = select(ChildChunk).where(
+ ChildChunk.index_node_id == document.metadata["doc_id"],
+ ChildChunk.dataset_id == dataset_document.dataset_id,
+ ChildChunk.document_id == dataset_document.id,
)
+ child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)
@@ -600,7 +595,8 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
- dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+ dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
+                dataset = session.scalar(dataset_stmt)
if not dataset:
return []
@@ -647,7 +643,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
- top_k=retrieval_model.get("top_k") or 2,
+ top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@@ -685,7 +681,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
- dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
+ dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
@@ -743,7 +740,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
- top_k=retrieve_config.top_k or 2,
+ top_k=retrieve_config.top_k or 4,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
@@ -958,7 +955,8 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
- metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+ metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
+ metadata_fields = db.session.scalars(metadata_stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
index cdfefbadb3..90c09a4441 100644
--- a/api/core/tools/tool_label_manager.py
+++ b/api/core/tools/tool_label_manager.py
@@ -1,3 +1,5 @@
+from sqlalchemy import select
+
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
@@ -54,17 +56,13 @@ class ToolLabelManager:
return controller.tool_labels
else:
raise ValueError("Unsupported tool type")
-
- labels = (
- db.session.query(ToolLabelBinding.label_name)
- .where(
- ToolLabelBinding.tool_id == provider_id,
- ToolLabelBinding.tool_type == controller.provider_type.value,
- )
- .all()
+ stmt = select(ToolLabelBinding.label_name).where(
+ ToolLabelBinding.tool_id == provider_id,
+ ToolLabelBinding.tool_type == controller.provider_type.value,
)
+ labels = db.session.scalars(stmt).all()
- return [label.label_name for label in labels]
+ return list(labels)
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
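
Because the statement selects a single column (`ToolLabelBinding.label_name`) rather than the entity, `scalars()` yields the string values directly, which is why the old `[label.label_name for label in labels]` comprehension collapses to `list(labels)`. A runnable sketch with a toy table:

```python
# Runnable sketch: selecting one column yields scalar values directly.
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class LabelBinding(Base):
    __tablename__ = "label_bindings"
    id: Mapped[int] = mapped_column(primary_key=True)
    label_name: Mapped[str] = mapped_column(String(40))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([LabelBinding(label_name="search"), LabelBinding(label_name="image")])
    session.commit()

    stmt = select(LabelBinding.label_name).order_by(LabelBinding.id)
    labels = session.scalars(stmt).all()  # plain strings, not ORM objects
    assert list(labels) == ["search", "image"]
```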
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index 474f8e3bcc..0069c6d6ee 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
+from sqlalchemy import select
from sqlalchemy.orm import Session
from yarl import URL
@@ -198,14 +199,11 @@ class ToolManager:
# get specific credentials
if is_valid_uuid(credential_id):
try:
- builtin_provider = (
- db.session.query(BuiltinToolProvider)
- .where(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.id == credential_id,
- )
- .first()
+ builtin_provider_stmt = select(BuiltinToolProvider).where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
)
+ builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e:
builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
@@ -319,11 +317,10 @@ class ToolManager:
),
)
elif provider_type == ToolProviderType.WORKFLOW:
- workflow_provider = (
- db.session.query(WorkflowToolProvider)
- .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
- .first()
+ workflow_provider_stmt = select(WorkflowToolProvider).where(
+ WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
+ workflow_provider = db.session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
@@ -333,16 +330,13 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
- return cast(
- WorkflowTool,
- controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
- runtime=ToolRuntime(
- tenant_id=tenant_id,
- credentials={},
- invoke_from=invoke_from,
- tool_invoke_from=tool_invoke_from,
- )
- ),
+ return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
+ runtime=ToolRuntime(
+ tenant_id=tenant_id,
+ credentials={},
+ invoke_from=invoke_from,
+ tool_invoke_from=tool_invoke_from,
+ )
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@@ -652,8 +646,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
- include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
- exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
+ include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+ exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
):
diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
index 7eb4bc017a..75c0c6738e 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
@@ -3,6 +3,7 @@ from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
+from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
@@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
- segments = (
- db.session.query(DocumentSegment)
- .where(
- DocumentSegment.dataset_id.in_(self.dataset_ids),
- DocumentSegment.completed_at.isnot(None),
- DocumentSegment.status == "completed",
- DocumentSegment.enabled == True,
- DocumentSegment.index_node_id.in_(index_node_ids),
- )
- .all()
+ document_segment_stmt = select(DocumentSegment).where(
+ DocumentSegment.dataset_id.in_(self.dataset_ids),
+ DocumentSegment.completed_at.isnot(None),
+ DocumentSegment.status == "completed",
+ DocumentSegment.enabled == True,
+ DocumentSegment.index_node_id.in_(index_node_ids),
)
+ segments = db.session.scalars(document_segment_stmt).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
- document = (
- db.session.query(Document)
- .where(
- Document.id == segment.document_id,
- Document.enabled == True,
- Document.archived == False,
- )
- .first()
+ document_stmt = select(Document).where(
+ Document.id == segment.document_id,
+ Document.enabled == True,
+ Document.archived == False,
)
+ document = db.session.scalar(document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
position=resource_number,
@@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
- dataset = (
- db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
- )
+ stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
+ dataset = db.session.scalar(stmt)
if not dataset:
return []
@@ -181,7 +175,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
- top_k=retrieval_model.get("top_k") or 2,
+ top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
@@ -192,7 +186,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
- top_k=retrieval_model.get("top_k") or 2,
+ top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
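
The retrieval defaults in this file, in `dataset_retrieval.py`, and in the base tool below move from `top_k=2` to `top_k=4`. The surrounding fallback expression `retrieval_model.get("top_k") or 4` is untouched, and it is worth remembering that `or` treats a missing key, `None`, and an explicit `0` the same way:

```python
# Sketch: how the `or 4` fallback behaves for different stored configs.
configs = [{}, {"top_k": None}, {"top_k": 0}, {"top_k": 7}]
resolved = [(c.get("top_k") or 4) for c in configs]
assert resolved == [4, 4, 4, 7]  # 0 falls back too, since 0 is falsy
```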
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
index 567275531e..4f489e00f4 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
@@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
- top_k: int = 2
+ top_k: int = 4
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index f7689d7707..b536c5a25c 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -1,6 +1,7 @@
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
+from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
@@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
)
def _run(self, query: str) -> str:
- dataset = (
- db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
- )
+ dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
+ dataset = db.session.scalar(dataset_stmt)
if not dataset:
return ""
@@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
- document = (
- db.session.query(DatasetDocument) # type: ignore
- .where(
- DatasetDocument.id == segment.document_id,
- DatasetDocument.enabled == True,
- DatasetDocument.archived == False,
- )
- .first()
+ dataset_document_stmt = select(DatasetDocument).where(
+ DatasetDocument.id == segment.document_id,
+ DatasetDocument.enabled == True,
+ DatasetDocument.archived == False,
)
+ document = db.session.scalar(dataset_document_stmt) # type: ignore
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 8357dac0d7..bf075bd730 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -3,7 +3,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
-from typing import Optional, cast
+from typing import Optional
from uuid import UUID
import numpy as np
@@ -159,8 +159,7 @@ class ToolFileMessageTransformer:
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
- json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
- json_msg.json_object = safe_json_value(json_msg.json_object)
+ message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message
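
The removed `cast()` was redundant: the `isinstance` check on the line above already narrows `message.message` to `JsonMessage` for the type checker, and the same reasoning (the type is already known statically) covers most of the other `cast` removals in this patch. A minimal illustration of that narrowing:

```python
# Minimal sketch of isinstance-based narrowing replacing typing.cast.
from typing import Union


class TextMessage:
    def __init__(self, text: str) -> None:
        self.text = text


class JsonMessage:
    def __init__(self, json_object: dict) -> None:
        self.json_object = json_object


def handle(message: Union[TextMessage, JsonMessage]) -> None:
    if isinstance(message, JsonMessage):
        # Inside this branch the checker already knows the type;
        # no cast(JsonMessage, message) is needed to touch .json_object.
        message.json_object = {"wrapped": message.json_object}


handle(JsonMessage({"a": 1}))
```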
diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py
index 3f59b3f472..251d914800 100644
--- a/api/core/tools/utils/model_invocation_utils.py
+++ b/api/core/tools/utils/model_invocation_utils.py
@@ -129,17 +129,14 @@ class ModelInvocationUtils:
db.session.commit()
try:
- response: LLMResult = cast(
- LLMResult,
- model_instance.invoke_llm(
- prompt_messages=prompt_messages,
- model_parameters=model_parameters,
- tools=[],
- stop=[],
- stream=False,
- user=user_id,
- callbacks=[],
- ),
+ response: LLMResult = model_instance.invoke_llm(
+ prompt_messages=prompt_messages,
+ model_parameters=model_parameters,
+ tools=[],
+ stop=[],
+ stream=False,
+ user=user_id,
+ callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index d8749f9851..c2092853ea 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -1,7 +1,9 @@
import json
import logging
from collections.abc import Generator
-from typing import Any, Optional, cast
+from typing import Any, Optional
+
+from sqlalchemy import select
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@@ -133,7 +135,8 @@ class WorkflowTool(Tool):
.first()
)
else:
- workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
+ stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
+ workflow = db.session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
@@ -144,7 +147,8 @@ class WorkflowTool(Tool):
"""
get the app by app id
"""
- app = db.session.query(App).where(App.id == app_id).first()
+ stmt = select(App).where(App.id == app_id)
+ app = db.session.scalar(stmt)
if not app:
raise ValueError("app not found")
@@ -201,14 +205,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
- tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
+ tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
- tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
+ tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index 16c8116ac1..a994730cd5 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -1,5 +1,5 @@
from collections.abc import Sequence
-from typing import Annotated, TypeAlias, cast
+from typing import Annotated, TypeAlias
from uuid import uuid4
from pydantic import Discriminator, Field, Tag
@@ -86,7 +86,7 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
- return cast(str, encrypter.obfuscated_token(self.value))
+ return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
diff --git a/api/core/workflow/graph_engine/graph_engine.py.orig b/api/core/workflow/graph_engine/graph_engine.py.orig
new file mode 100644
index 0000000000..833cee0ffe
--- /dev/null
+++ b/api/core/workflow/graph_engine/graph_engine.py.orig
@@ -0,0 +1,339 @@
+"""
+GraphEngine - main orchestrator for queue-based workflow execution.
+
+This engine uses a modular architecture with separated packages following
+Domain-Driven Design principles for improved maintainability and testability.
+"""
+
+import contextvars
+import logging
+import queue
+from collections.abc import Generator, Mapping
+from typing import final
+
+from flask import Flask, current_app
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities import GraphRuntimeState
+from core.workflow.enums import NodeExecutionType
+from core.workflow.graph import Graph
+from core.workflow.graph_events import (
+ GraphEngineEvent,
+ GraphNodeEventBase,
+ GraphRunAbortedEvent,
+ GraphRunFailedEvent,
+ GraphRunStartedEvent,
+ GraphRunSucceededEvent,
+)
+from models.enums import UserFrom
+
+from .command_processing import AbortCommandHandler, CommandProcessor
+from .domain import ExecutionContext, GraphExecution
+from .entities.commands import AbortCommand
+from .error_handling import ErrorHandler
+from .event_management import EventHandler, EventManager
+from .graph_traversal import EdgeProcessor, SkipPropagator
+from .layers.base import Layer
+from .orchestration import Dispatcher, ExecutionCoordinator
+from .protocols.command_channel import CommandChannel
+from .response_coordinator import ResponseStreamCoordinator
+from .state_management import UnifiedStateManager
+from .worker_management import SimpleWorkerPool
+
+logger = logging.getLogger(__name__)
+
+
+@final
+class GraphEngine:
+ """
+ Queue-based graph execution engine.
+
+ Uses a modular architecture that delegates responsibilities to specialized
+ subsystems, following Domain-Driven Design and SOLID principles.
+ """
+
+ def __init__(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ user_id: str,
+ user_from: UserFrom,
+ invoke_from: InvokeFrom,
+ call_depth: int,
+ graph: Graph,
+ graph_config: Mapping[str, object],
+ graph_runtime_state: GraphRuntimeState,
+ max_execution_steps: int,
+ max_execution_time: int,
+ command_channel: CommandChannel,
+ min_workers: int | None = None,
+ max_workers: int | None = None,
+ scale_up_threshold: int | None = None,
+ scale_down_idle_time: float | None = None,
+ ) -> None:
+ """Initialize the graph engine with all subsystems and dependencies."""
+
+ # === Domain Models ===
+ # Execution context encapsulates workflow execution metadata
+ self._execution_context = ExecutionContext(
+ tenant_id=tenant_id,
+ app_id=app_id,
+ workflow_id=workflow_id,
+ user_id=user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
+ call_depth=call_depth,
+ max_execution_steps=max_execution_steps,
+ max_execution_time=max_execution_time,
+ )
+
+ # Graph execution tracks the overall execution state
+ self._graph_execution = GraphExecution(workflow_id=workflow_id)
+
+ # === Core Dependencies ===
+ # Graph structure and configuration
+ self._graph = graph
+ self._graph_config = graph_config
+ self._graph_runtime_state = graph_runtime_state
+ self._command_channel = command_channel
+
+ # === Worker Management Parameters ===
+ # Parameters for dynamic worker pool scaling
+ self._min_workers = min_workers
+ self._max_workers = max_workers
+ self._scale_up_threshold = scale_up_threshold
+ self._scale_down_idle_time = scale_down_idle_time
+
+ # === Execution Queues ===
+ # Queue for nodes ready to execute
+ self._ready_queue: queue.Queue[str] = queue.Queue()
+ # Queue for events generated during execution
+ self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
+
+ # === State Management ===
+ # Unified state manager handles all node state transitions and queue operations
+ self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
+
+ # === Response Coordination ===
+ # Coordinates response streaming from response nodes
+ self._response_coordinator = ResponseStreamCoordinator(
+ variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
+ )
+
+ # === Event Management ===
+ # Event manager handles both collection and emission of events
+ self._event_manager = EventManager()
+
+ # === Error Handling ===
+ # Centralized error handler for graph execution errors
+ self._error_handler = ErrorHandler(self._graph, self._graph_execution)
+
+ # === Graph Traversal Components ===
+ # Propagates skip status through the graph when conditions aren't met
+ self._skip_propagator = SkipPropagator(
+ graph=self._graph,
+ state_manager=self._state_manager,
+ )
+
+ # Processes edges to determine next nodes after execution
+ # Also handles conditional branching and route selection
+ self._edge_processor = EdgeProcessor(
+ graph=self._graph,
+ state_manager=self._state_manager,
+ response_coordinator=self._response_coordinator,
+ skip_propagator=self._skip_propagator,
+ )
+
+ # === Event Handler Registry ===
+ # Central registry for handling all node execution events
+ self._event_handler_registry = EventHandler(
+ graph=self._graph,
+ graph_runtime_state=self._graph_runtime_state,
+ graph_execution=self._graph_execution,
+ response_coordinator=self._response_coordinator,
+ event_collector=self._event_manager,
+ edge_processor=self._edge_processor,
+ state_manager=self._state_manager,
+ error_handler=self._error_handler,
+ )
+
+ # === Command Processing ===
+ # Processes external commands (e.g., abort requests)
+ self._command_processor = CommandProcessor(
+ command_channel=self._command_channel,
+ graph_execution=self._graph_execution,
+ )
+
+ # Register abort command handler
+ abort_handler = AbortCommandHandler()
+ self._command_processor.register_handler(
+ AbortCommand,
+ abort_handler,
+ )
+
+ # === Worker Pool Setup ===
+ # Capture Flask app context for worker threads
+ flask_app: Flask | None = None
+ try:
+ app = current_app._get_current_object() # type: ignore
+ if isinstance(app, Flask):
+ flask_app = app
+ except RuntimeError:
+ pass
+
+ # Capture context variables for worker threads
+ context_vars = contextvars.copy_context()
+
+ # Create worker pool for parallel node execution
+ self._worker_pool = SimpleWorkerPool(
+ ready_queue=self._ready_queue,
+ event_queue=self._event_queue,
+ graph=self._graph,
+ flask_app=flask_app,
+ context_vars=context_vars,
+ min_workers=self._min_workers,
+ max_workers=self._max_workers,
+ scale_up_threshold=self._scale_up_threshold,
+ scale_down_idle_time=self._scale_down_idle_time,
+ )
+
+ # === Orchestration ===
+ # Coordinates the overall execution lifecycle
+ self._execution_coordinator = ExecutionCoordinator(
+ graph_execution=self._graph_execution,
+ state_manager=self._state_manager,
+ event_handler=self._event_handler_registry,
+ event_collector=self._event_manager,
+ command_processor=self._command_processor,
+ worker_pool=self._worker_pool,
+ )
+
+ # Dispatches events and manages execution flow
+ self._dispatcher = Dispatcher(
+ event_queue=self._event_queue,
+ event_handler=self._event_handler_registry,
+ event_collector=self._event_manager,
+ execution_coordinator=self._execution_coordinator,
+ max_execution_time=self._execution_context.max_execution_time,
+ event_emitter=self._event_manager,
+ )
+
+ # === Extensibility ===
+ # Layers allow plugins to extend engine functionality
+ self._layers: list[Layer] = []
+
+ # === Validation ===
+ # Ensure all nodes share the same GraphRuntimeState instance
+ self._validate_graph_state_consistency()
+
+ def _validate_graph_state_consistency(self) -> None:
+ """Validate that all nodes share the same GraphRuntimeState."""
+ expected_state_id = id(self._graph_runtime_state)
+ for node in self._graph.nodes.values():
+ if id(node.graph_runtime_state) != expected_state_id:
+ raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
+
+ def layer(self, layer: Layer) -> "GraphEngine":
+ """Add a layer for extending functionality."""
+ self._layers.append(layer)
+ return self
+
+ def run(self) -> Generator[GraphEngineEvent, None, None]:
+ """
+ Execute the graph using the modular architecture.
+
+ Returns:
+ Generator yielding GraphEngineEvent instances
+ """
+ try:
+ # Initialize layers
+ self._initialize_layers()
+
+ # Start execution
+ self._graph_execution.start()
+ start_event = GraphRunStartedEvent()
+ yield start_event
+
+ # Start subsystems
+ self._start_execution()
+
+ # Yield events as they occur
+ yield from self._event_manager.emit_events()
+
+ # Handle completion
+ if self._graph_execution.aborted:
+ abort_reason = "Workflow execution aborted by user command"
+ if self._graph_execution.error:
+ abort_reason = str(self._graph_execution.error)
+ yield GraphRunAbortedEvent(
+ reason=abort_reason,
+ outputs=self._graph_runtime_state.outputs,
+ )
+ elif self._graph_execution.has_error:
+ if self._graph_execution.error:
+ raise self._graph_execution.error
+ else:
+ yield GraphRunSucceededEvent(
+ outputs=self._graph_runtime_state.outputs,
+ )
+
+ except Exception as e:
+ yield GraphRunFailedEvent(error=str(e))
+ raise
+
+ finally:
+ self._stop_execution()
+
+ def _initialize_layers(self) -> None:
+ """Initialize layers with context."""
+ self._event_manager.set_layers(self._layers)
+ for layer in self._layers:
+ try:
+ layer.initialize(self._graph_runtime_state, self._command_channel)
+ except Exception as e:
+ logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
+
+ try:
+ layer.on_graph_start()
+ except Exception as e:
+ logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
+
+ def _start_execution(self) -> None:
+ """Start execution subsystems."""
+ # Start worker pool (it calculates initial workers internally)
+ self._worker_pool.start()
+
+ # Register response nodes
+ for node in self._graph.nodes.values():
+ if node.execution_type == NodeExecutionType.RESPONSE:
+ self._response_coordinator.register(node.id)
+
+ # Enqueue root node
+ root_node = self._graph.root_node
+ self._state_manager.enqueue_node(root_node.id)
+ self._state_manager.start_execution(root_node.id)
+
+ # Start dispatcher
+ self._dispatcher.start()
+
+ def _stop_execution(self) -> None:
+ """Stop execution subsystems."""
+ self._dispatcher.stop()
+ self._worker_pool.stop()
+ # Don't mark complete here as the dispatcher already does it
+
+ # Notify layers
+ logger = logging.getLogger(__name__)
+
+ for layer in self._layers:
+ try:
+ layer.on_graph_end(self._graph_execution.error)
+ except Exception as e:
+ logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
+
+ # Public property accessors for attributes that need external access
+ @property
+ def graph_runtime_state(self) -> GraphRuntimeState:
+ """Get the graph runtime state."""
+ return self._graph_runtime_state
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 57b58ab8f5..fa912d5035 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -157,7 +157,7 @@ class AgentNode(Node):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
- "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
+ "agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@@ -401,8 +401,7 @@ class AgentNode(Node):
current_plugin = next(
plugin
for plugin in plugins
- if f"{plugin.plugin_id}/{plugin.name}"
- == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
+ if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 65dde51191..ff241b9cf0 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -301,12 +301,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
- return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
+ return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
- return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
+ return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
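
`yaml.dump_all` returns a `str` whenever no output stream is passed, so the `cast(str, ...)` wrapper carried no information. A quick check:

```python
# Quick check: PyYAML's dump_all returns a str when no stream argument is given.
import yaml

documents = [{"name": "doc1"}, {"name": "doc2"}]
dumped = yaml.dump_all(documents, allow_unicode=True, sort_keys=False)
assert isinstance(dumped, str)
print(dumped)
```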
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index d7ccb338f5..bd5cab1e72 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
-from sqlalchemy import Float, and_, func, or_, text
+from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import sessionmaker
@@ -75,7 +75,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
- "top_k": 2,
+ "top_k": 4,
"score_threshold_enabled": False,
}
@@ -358,15 +358,12 @@ class KnowledgeRetrievalNode(Node):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
- document = (
- db.session.query(Document)
- .where(
- Document.id == segment.document_id,
- Document.enabled == True,
- Document.archived == False,
- )
- .first()
+ stmt = select(Document).where(
+ Document.id == segment.document_id,
+ Document.enabled == True,
+ Document.archived == False,
)
+ document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
@@ -505,7 +502,8 @@ class KnowledgeRetrievalNode(Node):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
- metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+ stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
+ metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 3e4882fd1e..648ea69936 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -138,7 +138,7 @@ class ParameterExtractorNode(Node):
"""
Run the node.
"""
- node_data = cast(ParameterExtractorNodeData, self._node_data)
+ node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 968332959c..afc45cf9cf 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -103,7 +103,7 @@ class QuestionClassifierNode(Node):
return "1"
def _run(self):
- node_data = cast(QuestionClassifierNodeData, self._node_data)
+ node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 39524dcd4f..9708edcb38 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -61,7 +61,7 @@ class ToolNode(Node):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
- node_data = cast(ToolNodeData, self._node_data)
+ node_data = self._node_data
# fetch tool icon
tool_info = {
diff --git a/api/core/workflow/nodes/tool/tool_node.py.orig b/api/core/workflow/nodes/tool/tool_node.py.orig
new file mode 100644
index 0000000000..9708edcb38
--- /dev/null
+++ b/api/core/workflow/nodes/tool/tool_node.py.orig
@@ -0,0 +1,493 @@
+from collections.abc import Generator, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.file import File, FileTransferMethod
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.errors import ToolInvokeError
+from core.tools.tool_engine import ToolEngine
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.variables.segments import ArrayAnySegment, ArrayFileSegment
+from core.variables.variables import ArrayAnyVariable
+from core.workflow.enums import (
+ ErrorStrategy,
+ NodeType,
+ SystemVariableKey,
+ WorkflowNodeExecutionMetadataKey,
+ WorkflowNodeExecutionStatus,
+)
+from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
+from extensions.ext_database import db
+from factories import file_factory
+from models import ToolFile
+from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+
+from .entities import ToolNodeData
+from .exc import (
+ ToolFileError,
+ ToolNodeError,
+ ToolParameterError,
+)
+
+if TYPE_CHECKING:
+ from core.workflow.entities import VariablePool
+
+
+class ToolNode(Node):
+ """
+ Tool Node
+ """
+
+ node_type = NodeType.TOOL
+
+ _node_data: ToolNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ToolNodeData.model_validate(data)
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
+ def _run(self) -> Generator:
+ """
+ Run the tool node
+ """
+ from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
+
+ node_data = self._node_data
+
+ # fetch tool icon
+ tool_info = {
+ "provider_type": node_data.provider_type.value,
+ "provider_id": node_data.provider_id,
+ "plugin_unique_identifier": node_data.plugin_unique_identifier,
+ }
+
+ # get tool runtime
+ try:
+ from core.tools.tool_manager import ToolManager
+
+        # Note: this check has caused problems before.
+        # Ideally we would not branch on node_data.version at all,
+        # but the check is kept here for backward compatibility with
+        # historical workflow data.
+ variable_pool: VariablePool | None = None
+ if node_data.version != "1" or node_data.tool_node_version != "1":
+ variable_pool = self.graph_runtime_state.variable_pool
+ tool_runtime = ToolManager.get_workflow_tool_runtime(
+ self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
+ )
+ except ToolNodeError as e:
+ yield StreamCompletedEvent(
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs={},
+ metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
+ error=f"Failed to get tool runtime: {str(e)}",
+ error_type=type(e).__name__,
+ )
+ )
+ return
+
+ # get parameters
+ tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
+ parameters = self._generate_parameters(
+ tool_parameters=tool_parameters,
+ variable_pool=self.graph_runtime_state.variable_pool,
+ node_data=self._node_data,
+ )
+ parameters_for_log = self._generate_parameters(
+ tool_parameters=tool_parameters,
+ variable_pool=self.graph_runtime_state.variable_pool,
+ node_data=self._node_data,
+ for_log=True,
+ )
+ # get conversation id
+ conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
+
+ try:
+ message_stream = ToolEngine.generic_invoke(
+ tool=tool_runtime,
+ tool_parameters=parameters,
+ user_id=self.user_id,
+ workflow_tool_callback=DifyWorkflowCallbackHandler(),
+ workflow_call_depth=self.workflow_call_depth,
+ app_id=self.app_id,
+ conversation_id=conversation_id.text if conversation_id else None,
+ )
+ except ToolNodeError as e:
+ yield StreamCompletedEvent(
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=parameters_for_log,
+ metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
+ error=f"Failed to invoke tool: {str(e)}",
+ error_type=type(e).__name__,
+ )
+ )
+ return
+
+ try:
+ # convert tool messages
+ yield from self._transform_message(
+ messages=message_stream,
+ tool_info=tool_info,
+ parameters_for_log=parameters_for_log,
+ user_id=self.user_id,
+ tenant_id=self.tenant_id,
+ node_id=self._node_id,
+ )
+ except ToolInvokeError as e:
+ yield StreamCompletedEvent(
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=parameters_for_log,
+ metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
+ error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
+ error_type=type(e).__name__,
+ )
+ )
+ except PluginInvokeError as e:
+ yield StreamCompletedEvent(
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=parameters_for_log,
+ metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
+ error="An error occurred in the plugin, "
+ f"please contact the author of {node_data.provider_name} for help, "
+ f"error type: {e.get_error_type()}, "
+ f"error details: {e.get_error_message()}",
+ error_type=type(e).__name__,
+ )
+ )
+ except PluginDaemonClientSideError as e:
+ yield StreamCompletedEvent(
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ inputs=parameters_for_log,
+ metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
+ error=f"Failed to invoke tool, error: {e.description}",
+ error_type=type(e).__name__,
+ )
+ )
+
+ def _generate_parameters(
+ self,
+ *,
+ tool_parameters: Sequence[ToolParameter],
+ variable_pool: "VariablePool",
+ node_data: ToolNodeData,
+ for_log: bool = False,
+ ) -> dict[str, Any]:
+ """
+ Generate parameters based on the given tool parameters, variable pool, and node data.
+
+ Args:
+ tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
+ variable_pool (VariablePool): The variable pool containing the variables.
+            node_data (ToolNodeData): The data associated with the tool node.
+            for_log (bool): Whether values are rendered for logging rather than execution.
+
+        Returns:
+            dict[str, Any]: A dictionary containing the generated parameters.
+
+ """
+ tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
+
+ result: dict[str, Any] = {}
+ for parameter_name in node_data.tool_parameters:
+ parameter = tool_parameters_dictionary.get(parameter_name)
+ if not parameter:
+ result[parameter_name] = None
+ continue
+ tool_input = node_data.tool_parameters[parameter_name]
+ if tool_input.type == "variable":
+ variable = variable_pool.get(tool_input.value)
+ if variable is None:
+ if parameter.required:
+ raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+ continue
+ parameter_value = variable.value
+ elif tool_input.type in {"mixed", "constant"}:
+ segment_group = variable_pool.convert_template(str(tool_input.value))
+ parameter_value = segment_group.log if for_log else segment_group.text
+ else:
+ raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+ result[parameter_name] = parameter_value
+
+ return result
+
+ def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
+ variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
+ assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
+ return list(variable.value) if variable else []
+
+ def _transform_message(
+ self,
+ messages: Generator[ToolInvokeMessage, None, None],
+ tool_info: Mapping[str, Any],
+ parameters_for_log: dict[str, Any],
+ user_id: str,
+ tenant_id: str,
+ node_id: str,
+ ) -> Generator:
+ """
+        Transform ToolInvokeMessages into stream chunk/completed events, collecting text, files, JSON and variables.
+ """
+ # transform message and handle file storage
+ from core.plugin.impl.plugin import PluginInstaller
+
+ message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+ messages=messages,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=None,
+ )
+
+ text = ""
+ files: list[File] = []
+ json: list[dict] = []
+
+ variables: dict[str, Any] = {}
+
+ for message in message_stream:
+ if message.type in {
+ ToolInvokeMessage.MessageType.IMAGE_LINK,
+ ToolInvokeMessage.MessageType.BINARY_LINK,
+ ToolInvokeMessage.MessageType.IMAGE,
+ }:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
+ url = message.message.text
+ if message.meta:
+ transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
+ else:
+ transfer_method = FileTransferMethod.TOOL_FILE
+
+ tool_file_id = str(url).split("/")[-1].split(".")[0]
+
+ with Session(db.engine) as session:
+ stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+ tool_file = session.scalar(stmt)
+ if tool_file is None:
+ raise ToolFileError(f"Tool file {tool_file_id} does not exist")
+
+ mapping = {
+ "tool_file_id": tool_file_id,
+ "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
+ "transfer_method": transfer_method,
+ "url": url,
+ }
+ file = file_factory.build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ files.append(file)
+ elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ # get tool file id
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert message.meta
+
+ tool_file_id = message.message.text.split("/")[-1].split(".")[0]
+ with Session(db.engine) as session:
+ stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+ tool_file = session.scalar(stmt)
+ if tool_file is None:
+ raise ToolFileError(f"tool file {tool_file_id} not exists")
+
+ mapping = {
+ "tool_file_id": tool_file_id,
+ "transfer_method": FileTransferMethod.TOOL_FILE,
+ }
+
+ files.append(
+ file_factory.build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ )
+ elif message.type == ToolInvokeMessage.MessageType.TEXT:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ text += message.message.text
+ yield StreamChunkEvent(
+ selector=[node_id, "text"],
+ chunk=message.message.text,
+ is_final=False,
+ )
+ elif message.type == ToolInvokeMessage.MessageType.JSON:
+ assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+ # JSON message handling for tool node
+ if message.message.json_object is not None:
+ json.append(message.message.json_object)
+ elif message.type == ToolInvokeMessage.MessageType.LINK:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ stream_text = f"Link: {message.message.text}\n"
+ text += stream_text
+ yield StreamChunkEvent(
+ selector=[node_id, "text"],
+ chunk=stream_text,
+ is_final=False,
+ )
+ elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
+ assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+ variable_name = message.message.variable_name
+ variable_value = message.message.variable_value
+ if message.message.stream:
+ if not isinstance(variable_value, str):
+ raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
+ if variable_name not in variables:
+ variables[variable_name] = ""
+ variables[variable_name] += variable_value
+
+ yield StreamChunkEvent(
+ selector=[node_id, variable_name],
+ chunk=variable_value,
+ is_final=False,
+ )
+ else:
+ variables[variable_name] = variable_value
+ elif message.type == ToolInvokeMessage.MessageType.FILE:
+ assert message.meta is not None
+ assert isinstance(message.meta, dict)
+ # Validate that meta contains a 'file' key
+ if "file" not in message.meta:
+ raise ToolNodeError("File message is missing 'file' key in meta")
+
+ # Validate that the file is an instance of File
+ if not isinstance(message.meta["file"], File):
+ raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
+ files.append(message.meta["file"])
+ elif message.type == ToolInvokeMessage.MessageType.LOG:
+ assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+ if message.message.metadata:
+ icon = tool_info.get("icon", "")
+ dict_metadata = dict(message.message.metadata)
+ if dict_metadata.get("provider"):
+ manager = PluginInstaller()
+ plugins = manager.list_plugins(tenant_id)
+ try:
+ current_plugin = next(
+ plugin
+ for plugin in plugins
+ if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
+ )
+ icon = current_plugin.declaration.icon
+ except StopIteration:
+ pass
+ icon_dark = None
+ try:
+ builtin_tool = next(
+ provider
+ for provider in BuiltinToolManageService.list_builtin_tools(
+ user_id,
+ tenant_id,
+ )
+ if provider.name == dict_metadata["provider"]
+ )
+ icon = builtin_tool.icon
+ icon_dark = builtin_tool.icon_dark
+ except StopIteration:
+ pass
+
+ dict_metadata["icon"] = icon
+ dict_metadata["icon_dark"] = icon_dark
+ message.message.metadata = dict_metadata
+
+        # Collect JSON messages (including agent logs) into outputs['json'] so the frontend can access the thinking process
+ json_output: list[dict[str, Any]] = []
+
+        # Normalize to list[dict]: use the collected JSON messages, otherwise fall back to a single {"data": []} entry
+ if json:
+ json_output.extend(json)
+ else:
+ json_output.append({"data": []})
+
+ # Send final chunk events for all streamed outputs
+ # Final chunk for text stream
+ yield StreamChunkEvent(
+ selector=[self._node_id, "text"],
+ chunk="",
+ is_final=True,
+ )
+
+ # Final chunks for any streamed variables
+ for var_name in variables:
+ yield StreamChunkEvent(
+ selector=[self._node_id, var_name],
+ chunk="",
+ is_final=True,
+ )
+
+ yield StreamCompletedEvent(
+ node_run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
+ metadata={
+ WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
+ },
+ inputs=parameters_for_log,
+ )
+ )
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(
+ cls,
+ *,
+ graph_config: Mapping[str, Any],
+ node_id: str,
+ node_data: Mapping[str, Any],
+ ) -> Mapping[str, Sequence[str]]:
+ """
+ Extract variable selector to variable mapping
+ :param graph_config: graph config
+ :param node_id: node id
+ :param node_data: node data
+ :return:
+ """
+ # Create typed NodeData from dict
+ typed_node_data = ToolNodeData.model_validate(node_data)
+
+ result = {}
+ for parameter_name in typed_node_data.tool_parameters:
+ input = typed_node_data.tool_parameters[parameter_name]
+ if input.type == "mixed":
+ assert isinstance(input.value, str)
+ selectors = VariableTemplateParser(input.value).extract_variable_selectors()
+ for selector in selectors:
+ result[selector.variable] = selector.value_selector
+ elif input.type == "variable":
+ result[parameter_name] = input.value
+ elif input.type == "constant":
+ pass
+
+ result = {node_id + "." + key: value for key, value in result.items()}
+
+ return result
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
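
The parameter-resolution branches in _generate_parameters above are easier to read in isolation. The following is a simplified illustration only, not the node's actual code path: SimplePool and the plain-dict tool inputs are hypothetical stand-ins for Dify's VariablePool and ToolNodeData.tool_parameters, and template rendering for "mixed" inputs is omitted.

from typing import Any


class SimplePool:
    """Hypothetical stand-in for Dify's VariablePool, keyed by selector tuples."""

    def __init__(self, values: dict[tuple[str, ...], Any]) -> None:
        self._values = values

    def get(self, selector: list[str]) -> Any:
        return self._values.get(tuple(selector))


def resolve(tool_inputs: dict[str, dict[str, Any]], pool: SimplePool) -> dict[str, Any]:
    # Mirrors the branching in _generate_parameters: "variable" inputs are looked up
    # in the pool, "mixed"/"constant" inputs are rendered as text (real template
    # rendering is omitted), anything else is rejected.
    result: dict[str, Any] = {}
    for name, tool_input in tool_inputs.items():
        if tool_input["type"] == "variable":
            result[name] = pool.get(tool_input["value"])
        elif tool_input["type"] in {"mixed", "constant"}:
            result[name] = str(tool_input["value"])
        else:
            raise ValueError(f"Unknown tool input type '{tool_input['type']}'")
    return result


pool = SimplePool({("start", "query"): "weather in Berlin"})
print(resolve({"q": {"type": "variable", "value": ["start", "query"]}}, pool))
# -> {'q': 'weather in Berlin'}
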
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 78f7b39a06..6f1c530b14 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, Optional
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -318,7 +318,6 @@ class WorkflowEntry:
# init graph
graph = Graph.init(graph_config=graph_dict, node_factory=node_factory)
- node_cls = cast(type[Node], node_cls)
# init workflow run state
node_config = {
"id": node_id,
diff --git a/api/core/workflow/workflow_entry.py.orig b/api/core/workflow/workflow_entry.py.orig
new file mode 100644
index 0000000000..69dd1bdebc
--- /dev/null
+++ b/api/core/workflow/workflow_entry.py.orig
@@ -0,0 +1,445 @@
+import logging
+import time
+import uuid
+from collections.abc import Generator, Mapping, Sequence
+from typing import Any, Optional, cast
+
+from configs import dify_config
+from core.app.apps.exc import GenerateTaskStoppedError
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file.models import File
+from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
+from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.errors import WorkflowNodeRunFailedError
+from core.workflow.graph import Graph
+from core.workflow.graph_engine import GraphEngine
+from core.workflow.graph_engine.command_channels import InMemoryChannel
+from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
+from core.workflow.graph_engine.protocols.command_channel import CommandChannel
+from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.base.node import Node
+from core.workflow.nodes.node_factory import DifyNodeFactory
+from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.system_variable import SystemVariable
+from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
+from factories import file_factory
+from models.enums import UserFrom
+from models.workflow import Workflow
+
+logger = logging.getLogger(__name__)
+
+
+class WorkflowEntry:
+ def __init__(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ graph_config: Mapping[str, Any],
+ graph: Graph,
+ user_id: str,
+ user_from: UserFrom,
+ invoke_from: InvokeFrom,
+ call_depth: int,
+ graph_runtime_state: GraphRuntimeState,
+ command_channel: Optional[CommandChannel] = None,
+ ) -> None:
+ """
+ Init workflow entry
+        :param tenant_id: tenant id
+        :param app_id: app id
+        :param workflow_id: workflow id
+        :param graph_config: workflow graph config
+        :param graph: workflow graph
+        :param user_id: user id
+        :param user_from: user from
+        :param invoke_from: invoke from
+        :param call_depth: call depth
+        :param graph_runtime_state: pre-created graph runtime state
+        :param command_channel: command channel for external control (optional, defaults to InMemoryChannel)
+ """
+ # check call depth
+ workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
+ if call_depth > workflow_call_max_depth:
+ raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.")
+
+ # Use provided command channel or default to InMemoryChannel
+ if command_channel is None:
+ command_channel = InMemoryChannel()
+
+ self.command_channel = command_channel
+ self.graph_engine = GraphEngine(
+ tenant_id=tenant_id,
+ app_id=app_id,
+ workflow_id=workflow_id,
+ user_id=user_id,
+ user_from=user_from,
+ invoke_from=invoke_from,
+ call_depth=call_depth,
+ graph=graph,
+ graph_config=graph_config,
+ graph_runtime_state=graph_runtime_state,
+ max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
+ max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
+ command_channel=command_channel,
+ )
+
+ # Add debug logging layer when in debug mode
+ if dify_config.DEBUG:
+ logger.info("Debug mode enabled - adding DebugLoggingLayer to GraphEngine")
+ debug_layer = DebugLoggingLayer(
+ level="DEBUG",
+ include_inputs=True,
+ include_outputs=True,
+ include_process_data=False, # Process data can be very verbose
+ logger_name=f"GraphEngine.Debug.{workflow_id[:8]}", # Use workflow ID prefix for unique logger
+ )
+ self.graph_engine.layer(debug_layer)
+
+ # Add execution limits layer
+ limits_layer = ExecutionLimitsLayer(
+ max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
+ )
+ self.graph_engine.layer(limits_layer)
+
+ def run(self) -> Generator[GraphEngineEvent, None, None]:
+ graph_engine = self.graph_engine
+
+ try:
+ # run workflow
+ generator = graph_engine.run()
+ yield from generator
+ except GenerateTaskStoppedError:
+ pass
+ except Exception as e:
+ logger.exception("Unknown Error when workflow entry running")
+ yield GraphRunFailedEvent(error=str(e))
+ return
+
+ @classmethod
+ def single_step_run(
+ cls,
+ *,
+ workflow: Workflow,
+ node_id: str,
+ user_id: str,
+ user_inputs: Mapping[str, Any],
+ variable_pool: VariablePool,
+ variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
+ ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
+ """
+ Single step run workflow node
+ :param workflow: Workflow instance
+ :param node_id: node id
+ :param user_id: user id
+ :param user_inputs: user inputs
+ :return:
+ """
+ node_config = workflow.get_node_config_by_id(node_id)
+ node_config_data = node_config.get("data", {})
+
+ # Get node class
+ node_type = NodeType(node_config_data.get("type"))
+ node_version = node_config_data.get("version", "1")
+ node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
+
+ # init graph init params and runtime state
+ graph_init_params = GraphInitParams(
+ tenant_id=workflow.tenant_id,
+ app_id=workflow.app_id,
+ workflow_id=workflow.id,
+ graph_config=workflow.graph_dict,
+ user_id=user_id,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
+ call_depth=0,
+ )
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
+ # init node factory
+ node_factory = DifyNodeFactory(
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ # init graph
+ graph = Graph.init(graph_config=workflow.graph_dict, node_factory=node_factory)
+
+ # init workflow run state
+ node = node_cls(
+ id=str(uuid.uuid4()),
+ config=node_config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ node.init_node_data(node_config_data)
+
+ try:
+ # variable selector to variable mapping
+ variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+ graph_config=workflow.graph_dict, config=node_config
+ )
+ except NotImplementedError:
+ variable_mapping = {}
+
+ # Loading missing variable from draft var here, and set it into
+ # variable_pool.
+ load_into_variable_pool(
+ variable_loader=variable_loader,
+ variable_pool=variable_pool,
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ )
+
+ cls.mapping_user_inputs_to_variable_pool(
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ variable_pool=variable_pool,
+ tenant_id=workflow.tenant_id,
+ )
+
+ try:
+ # run node
+ generator = node.run()
+ except Exception as e:
+ logger.exception(
+ "error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
+ workflow.id,
+ node.id,
+ node.node_type,
+ node.version(),
+ )
+ raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
+ return node, generator
+
+ @staticmethod
+ def _create_single_node_graph(
+ node_id: str,
+ node_data: dict[str, Any],
+ node_width: int = 114,
+ node_height: int = 514,
+ ) -> dict[str, Any]:
+ """
+ Create a minimal graph structure for testing a single node in isolation.
+
+ :param node_id: ID of the target node
+ :param node_data: configuration data for the target node
+        :param node_width: width for UI layout (default: 114)
+        :param node_height: height for UI layout (default: 514)
+ :return: graph dictionary with start node and target node
+ """
+ node_config = {
+ "id": node_id,
+ "width": node_width,
+ "height": node_height,
+ "type": "custom",
+ "data": node_data,
+ }
+ start_node_config = {
+ "id": "start",
+ "width": node_width,
+ "height": node_height,
+ "type": "custom",
+ "data": {
+ "type": NodeType.START.value,
+ "title": "Start",
+ "desc": "Start",
+ },
+ }
+ return {
+ "nodes": [start_node_config, node_config],
+ "edges": [
+ {
+ "source": "start",
+ "target": node_id,
+ "sourceHandle": "source",
+ "targetHandle": "target",
+ }
+ ],
+ }
+
+ @classmethod
+ def run_free_node(
+ cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
+ ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
+ """
+ Run free node
+
+ NOTE: only parameter_extractor/question_classifier are supported
+
+ :param node_data: node data
+ :param node_id: node id
+ :param tenant_id: tenant id
+ :param user_id: user id
+ :param user_inputs: user inputs
+ :return:
+ """
+ # Create a minimal graph for single node execution
+ graph_dict = cls._create_single_node_graph(node_id, node_data)
+
+ node_type = NodeType(node_data.get("type", ""))
+ if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
+ raise ValueError(f"Node type {node_type} not supported")
+
+ node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
+ if not node_cls:
+ raise ValueError(f"Node class not found for node type {node_type}")
+
+ # init variable pool
+ variable_pool = VariablePool(
+ system_variables=SystemVariable.empty(),
+ user_inputs={},
+ environment_variables=[],
+ )
+
+ # init graph init params and runtime state
+ graph_init_params = GraphInitParams(
+ tenant_id=tenant_id,
+ app_id="",
+ workflow_id="",
+ graph_config=graph_dict,
+ user_id=user_id,
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
+ call_depth=0,
+ )
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
+ # init node factory
+ node_factory = DifyNodeFactory(
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ # init graph
+ graph = Graph.init(graph_config=graph_dict, node_factory=node_factory)
+
+ node_cls = cast(type[Node], node_cls)
+ # init workflow run state
+ node_config = {
+ "id": node_id,
+ "data": node_data,
+ }
+ node: Node = node_cls(
+ id=str(uuid.uuid4()),
+ config=node_config,
+ graph_init_params=graph_init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+ node.init_node_data(node_data)
+
+ try:
+ # variable selector to variable mapping
+ try:
+ variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+ graph_config=graph_dict, config=node_config
+ )
+ except NotImplementedError:
+ variable_mapping = {}
+
+ cls.mapping_user_inputs_to_variable_pool(
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ variable_pool=variable_pool,
+ tenant_id=tenant_id,
+ )
+
+ # run node
+ generator = node.run()
+
+ return node, generator
+ except Exception as e:
+ logger.exception(
+ "error while running node, node_id=%s, node_type=%s, node_version=%s",
+ node.id,
+ node.node_type,
+ node.version(),
+ )
+ raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
+
+ @staticmethod
+ def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
+ # NOTE(QuantumGhost): Avoid using this function in new code.
+ # Keep values structured as long as possible and only convert to dict
+ # immediately before serialization (e.g., JSON serialization) to maintain
+ # data integrity and type information.
+ result = WorkflowEntry._handle_special_values(value)
+ return result if isinstance(result, Mapping) or result is None else dict(result)
+
+ @staticmethod
+ def _handle_special_values(value: Any) -> Any:
+ if value is None:
+ return value
+ if isinstance(value, dict):
+ res = {}
+ for k, v in value.items():
+ res[k] = WorkflowEntry._handle_special_values(v)
+ return res
+ if isinstance(value, list):
+ res_list = []
+ for item in value:
+ res_list.append(WorkflowEntry._handle_special_values(item))
+ return res_list
+ if isinstance(value, File):
+ return value.to_dict()
+ return value
+
+ @classmethod
+ def mapping_user_inputs_to_variable_pool(
+ cls,
+ *,
+ variable_mapping: Mapping[str, Sequence[str]],
+ user_inputs: Mapping[str, Any],
+ variable_pool: VariablePool,
+ tenant_id: str,
+ ) -> None:
+ # NOTE(QuantumGhost): This logic should remain synchronized with
+ # the implementation of `load_into_variable_pool`, specifically the logic about
+ # variable existence checking.
+
+ # WARNING(QuantumGhost): The semantics of this method are not clearly defined,
+ # and multiple parts of the codebase depend on its current behavior.
+ # Modify with caution.
+ for node_variable, variable_selector in variable_mapping.items():
+ # fetch node id and variable key from node_variable
+ node_variable_list = node_variable.split(".")
+ if len(node_variable_list) < 1:
+ raise ValueError(f"Invalid node variable {node_variable}")
+
+ node_variable_key = ".".join(node_variable_list[1:])
+
+ if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
+ variable_selector
+ ):
+ raise ValueError(f"Variable key {node_variable} not found in user inputs.")
+
+ # environment variable already exist in variable pool, not from user inputs
+ if variable_pool.get(variable_selector):
+ continue
+
+ # fetch variable node id from variable selector
+ variable_node_id = variable_selector[0]
+ variable_key_list = variable_selector[1:]
+ variable_key_list = list(variable_key_list)
+
+ # get input value
+ input_value = user_inputs.get(node_variable)
+ if not input_value:
+ input_value = user_inputs.get(node_variable_key)
+
+ if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
+ input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
+ if (
+ isinstance(input_value, list)
+ and all(isinstance(item, dict) for item in input_value)
+ and all("type" in item and "transfer_method" in item for item in input_value)
+ ):
+ input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
+
+ # append variable and value to variable pool
+ if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
+ variable_pool.add([variable_node_id] + variable_key_list, input_value)
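
The _handle_special_values helper above recursively walks dicts and lists and converts File objects to plain dicts via to_dict() just before output. A minimal standalone sketch of the same recursion, using a hypothetical FakeFile in place of core.file.models.File:

from typing import Any


class FakeFile:
    """Hypothetical stand-in for core.file.models.File."""

    def __init__(self, name: str) -> None:
        self.name = name

    def to_dict(self) -> dict[str, Any]:
        return {"name": self.name}


def handle_special_values(value: Any) -> Any:
    # Recurse into containers; convert file-like objects to plain dicts.
    if value is None:
        return None
    if isinstance(value, dict):
        return {k: handle_special_values(v) for k, v in value.items()}
    if isinstance(value, list):
        return [handle_special_values(item) for item in value]
    if isinstance(value, FakeFile):
        return value.to_dict()
    return value


print(handle_special_values({"files": [FakeFile("a.png")], "n": 1}))
# -> {'files': [{'name': 'a.png'}], 'n': 1}
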
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
index 6680bc692d..35feaae4e0 100644
--- a/api/events/event_handlers/update_provider_when_message_created.py
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Ch
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
+from extensions.ext_redis import redis_client, redis_fallback
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType
@@ -19,6 +20,32 @@ from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
+# Redis cache key prefix for provider last used timestamps
+_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"
+# Default TTL for cache entries (10 minutes)
+_CACHE_TTL_SECONDS = 600
+LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5
+
+
+def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
+ """Generate Redis cache key for provider last used timestamp."""
+ return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}"
+
+
+@redis_fallback(default_return=None)
+def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]:
+ """Get last update timestamp from Redis cache."""
+ timestamp_str = redis_client.get(cache_key)
+ if timestamp_str:
+ return datetime.fromtimestamp(float(timestamp_str.decode("utf-8")))
+ return None
+
+
+@redis_fallback()
+def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None:
+ """Set last update timestamp in Redis cache with TTL."""
+ redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp()))
+
class _ProviderUpdateFilters(BaseModel):
"""Filters for identifying Provider records to update."""
@@ -139,7 +166,7 @@ def handle(sender: Message, **kwargs):
provider_name,
)
- except Exception as e:
+ except Exception:
# Log failure with timing and context
duration = time_module.perf_counter() - start_time
@@ -215,8 +242,23 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
# Prepare values dict for SQLAlchemy update
update_values = {}
- # updateing to `last_used` is removed due to performance reason.
- # ref: https://github.com/langgenius/dify/issues/24526
+
+ # NOTE: For frequently used providers under high load, this implementation may experience
+ # race conditions or update contention despite the time-window optimization:
+ # 1. Multiple concurrent requests might check the same cache key simultaneously
+ # 2. Redis cache operations are not atomic with database updates
+ # 3. Heavy providers could still face database lock contention during peak usage
+    # The current implementation is acceptable for most scenarios; future optimizations
+    # could include batched updates or asynchronous processing.
+ if values.last_used is not None:
+ cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
+ now = datetime_utils.naive_utc_now()
+ last_update = _get_last_update_timestamp(cache_key)
+
+ if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
+ update_values["last_used"] = values.last_used
+ _set_last_update_timestamp(cache_key, now)
+
if values.quota_used is not None:
update_values["quota_used"] = values.quota_used
# Skip the current update operation if no updates are required.
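
The last_used throttling added above combines a Redis-cached timestamp with a five-minute window. The window logic itself is independent of Redis; the sketch below substitutes a plain in-memory dict for redis_client/redis_fallback purely for illustration and is not the handler's actual code.

from datetime import datetime, timedelta

UPDATE_WINDOW = timedelta(minutes=5)
_last_write: dict[str, datetime] = {}  # stand-in for the Redis cache


def should_write_last_used(cache_key: str, now: datetime) -> bool:
    # Only persist last_used if no write happened inside the window.
    previous = _last_write.get(cache_key)
    if previous is not None and now - previous <= UPDATE_WINDOW:
        return False
    _last_write[cache_key] = now
    return True


now = datetime(2025, 1, 1, 12, 0, 0)
print(should_write_last_used("provider:last_used:t1:openai", now))                         # True
print(should_write_last_used("provider:last_used:t1:openai", now + timedelta(minutes=1)))  # False
print(should_write_last_used("provider:last_used:t1:openai", now + timedelta(minutes=6)))  # True
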
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 0ea7d3ae1e..62e3bfa3ba 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -3,7 +3,7 @@ import os
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
-from typing import Any, cast
+from typing import Any
import httpx
from sqlalchemy import select
@@ -258,7 +258,6 @@ def _get_remote_file_info(url: str):
mime_type = ""
resp = ssrf_proxy.head(url, follow_redirects=True)
- resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index f048d0f3b6..53cb9de3ee 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -17,7 +17,7 @@ class EnvironmentVariableField(fields.Raw):
return {
"id": value.id,
"name": value.name,
- "value": encrypter.obfuscated_token(value.value),
+ "value": encrypter.full_mask_token(),
"value_type": value.value_type.value,
"description": value.description,
}
diff --git a/api/migrations/versions/2025_08_29_1534-b95962a3885c_add_workflow_app_log_run_id_index.py b/api/migrations/versions/2025_08_29_1534-b95962a3885c_add_workflow_app_log_run_id_index.py
new file mode 100644
index 0000000000..465f8664a5
--- /dev/null
+++ b/api/migrations/versions/2025_08_29_1534-b95962a3885c_add_workflow_app_log_run_id_index.py
@@ -0,0 +1,32 @@
+"""chore: add workflow app log run id index
+
+Revision ID: b95962a3885c
+Revises: 8d289573e1da
+Create Date: 2025-08-29 15:34:09.838623
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'b95962a3885c'
+down_revision = '8d289573e1da'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op:
+ batch_op.create_index('workflow_app_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op:
+ batch_op.drop_index('workflow_app_log_workflow_run_id_idx')
+ # ### end Alembic commands ###
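
Once this migration has been applied, the presence of the new index can be checked from Python with SQLAlchemy's inspector. The connection URL below is a placeholder and must be replaced with the real database DSN.

from sqlalchemy import create_engine, inspect

# Placeholder DSN; point this at the actual Dify database.
engine = create_engine("postgresql+psycopg2://user:pass@localhost/dify")

inspector = inspect(engine)
index_names = [ix["name"] for ix in inspector.get_indexes("workflow_app_logs")]
assert "workflow_app_log_workflow_run_id_idx" in index_names
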
diff --git a/api/models/tools.py b/api/models/tools.py
index 26bbc03694..2d15b778e9 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -319,7 +319,7 @@ class MCPToolProvider(Base):
@property
def decrypted_server_url(self) -> str:
- return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
+ return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def masked_server_url(self) -> str:
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 842b227028..3c449c89cc 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -835,6 +835,7 @@ class WorkflowAppLog(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
+ sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 089e667166..50ce171ded 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -146,7 +146,7 @@ class AccountService:
account.last_active_at = naive_utc_now()
db.session.commit()
- return cast(Account, account)
+ return account
@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -191,7 +191,7 @@ class AccountService:
db.session.commit()
- return cast(Account, account)
+ return account
@staticmethod
def update_account_password(account, password, new_password):
@@ -1127,7 +1127,7 @@ class TenantService:
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)
- return cast(dict, tenant.custom_config_dict)
+ return tenant.custom_config_dict
@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 6603063c22..9ee92bc2dc 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -1,5 +1,5 @@
import uuid
-from typing import cast
+from typing import Optional
import pandas as pd
from flask_login import current_user
@@ -40,7 +40,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
- annotation = message.annotation
+ annotation: Optional[MessageAnnotation] = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@@ -70,7 +70,7 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
- return cast(MessageAnnotation, annotation)
+ return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 0d10aa15dd..bc574fd8bc 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -1149,7 +1149,7 @@ class DocumentService:
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
- "top_k": 2,
+ "top_k": 4,
"score_threshold_enabled": False,
}
@@ -1612,7 +1612,7 @@ class DocumentService:
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
- top_k=2,
+ top_k=4,
score_threshold_enabled=False,
)
# save dataset
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 1517ca6594..bce28da032 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -18,7 +18,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
- "top_k": 2,
+ "top_k": 4,
"score_threshold_enabled": False,
}
@@ -66,7 +66,7 @@ class HitTestingService:
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
- top_k=retrieval_model.get("top_k", 2),
+ top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
diff --git a/api/services/message_service.py b/api/services/message_service.py
index a19d6ee157..13c8e948ca 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -112,7 +112,9 @@ class MessageService:
base_query = base_query.where(Message.conversation_id == conversation.id)
# Check if include_ids is not None and not empty to avoid WHERE false condition
- if include_ids is not None and len(include_ids) > 0:
+ if include_ids is not None:
+ if len(include_ids) == 0:
+ return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = base_query.where(Message.id.in_(include_ids))
if last_id:
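
The early return added above avoids running a query whose IN () clause can never match anything. A stripped-down sketch of the same guard, with a hypothetical in-memory fetch_messages standing in for the service method and its SQLAlchemy query:

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class Page:
    data: list[Any] = field(default_factory=list)
    limit: int = 20
    has_more: bool = False


def fetch_messages(all_messages: list[dict], include_ids: Optional[list[str]], limit: int = 20) -> Page:
    if include_ids is not None:
        if len(include_ids) == 0:
            # Nothing can match an empty IN clause, so return an empty page
            # instead of running a query that is guaranteed to be empty.
            return Page(limit=limit)
        all_messages = [m for m in all_messages if m["id"] in include_ids]
    return Page(data=all_messages[:limit], limit=limit, has_more=len(all_messages) > limit)


print(fetch_messages([{"id": "a"}, {"id": "b"}], include_ids=[]).data)     # []
print(fetch_messages([{"id": "a"}, {"id": "b"}], include_ids=["b"]).data)  # [{'id': 'b'}]
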
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index 88e8697d17..955e898ec1 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -27,73 +27,73 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
- if dataset is None:
- logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
- db.session.close()
- return
-
- # check document limit
- features = FeatureService.get_features(dataset.tenant_id)
try:
- if features.billing.enabled:
- vector_space = features.vector_space
- count = len(document_ids)
- if features.billing.subscription.plan == "sandbox" and count > 1:
- raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
- batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
- if count > batch_upload_limit:
- raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
- if 0 < vector_space.limit <= vector_space.size:
- raise ValueError(
- "Your total number of documents plus the number of uploads have over the limit of "
- "your subscription."
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+ if dataset is None:
+ logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+ db.session.close()
+ return
+
+ # check document limit
+ features = FeatureService.get_features(dataset.tenant_id)
+ try:
+ if features.billing.enabled:
+ vector_space = features.vector_space
+ count = len(document_ids)
+ if features.billing.subscription.plan == "sandbox" and count > 1:
+ raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+ batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+ if count > batch_upload_limit:
+ raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+ if 0 < vector_space.limit <= vector_space.size:
+ raise ValueError(
+ "Your total number of documents plus the number of uploads have over the limit of "
+ "your subscription."
+ )
+ except Exception as e:
+ for document_id in document_ids:
+ document = (
+ db.session.query(Document)
+ .where(Document.id == document_id, Document.dataset_id == dataset_id)
+ .first()
)
- except Exception as e:
+ if document:
+ document.indexing_status = "error"
+ document.error = str(e)
+ document.stopped_at = naive_utc_now()
+ db.session.add(document)
+ db.session.commit()
+ return
+
for document_id in document_ids:
+ logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
+
if document:
- document.indexing_status = "error"
- document.error = str(e)
- document.stopped_at = naive_utc_now()
+ # clean old data
+ index_type = document.doc_form
+ index_processor = IndexProcessorFactory(index_type).init_index_processor()
+
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+ if segments:
+ index_node_ids = [segment.index_node_id for segment in segments]
+
+ # delete from vector index
+ index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+ for segment in segments:
+ db.session.delete(segment)
+ db.session.commit()
+
+ document.indexing_status = "parsing"
+ document.processing_started_at = naive_utc_now()
+ documents.append(document)
db.session.add(document)
db.session.commit()
- return
- finally:
- db.session.close()
- for document_id in document_ids:
- logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
- document = (
- db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
- )
-
- if document:
- # clean old data
- index_type = document.doc_form
- index_processor = IndexProcessorFactory(index_type).init_index_processor()
-
- segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
- if segments:
- index_node_ids = [segment.index_node_id for segment in segments]
-
- # delete from vector index
- index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
- for segment in segments:
- db.session.delete(segment)
- db.session.commit()
-
- document.indexing_status = "parsing"
- document.processing_started_at = naive_utc_now()
- documents.append(document)
- db.session.add(document)
- db.session.commit()
-
- try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
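
The restructuring above pulls the dataset check and the per-document loop into one outer try block; the matching except/finally is outside the visible hunk, so only the generic shape of the pattern is sketched below, with a hypothetical FakeSession in place of db.session.

import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def task_session(session):
    # One outer try: any failure is logged and re-raised, and the session
    # is always closed, regardless of which step raised.
    try:
        yield session
    except Exception:
        logger.exception("task failed")
        raise
    finally:
        session.close()


class FakeSession:
    """Hypothetical stand-in for a SQLAlchemy session."""

    def close(self) -> None:
        print("session closed")


with task_session(FakeSession()) as session:
    print("doing per-document work")
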
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py
index e6f3f0ddf6..fe46ed7658 100644
--- a/api/tests/integration_tests/workflow/nodes/test_code.py
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py
@@ -1,7 +1,6 @@
import time
import uuid
from os import getenv
-from typing import cast
import pytest
@@ -11,7 +10,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
-from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@@ -242,8 +240,6 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
- node._node_data = cast(CodeNodeData, node._node_data)
-
# validate
node._transform_result(result, node._node_data.outputs)
@@ -338,8 +334,6 @@ def test_execute_code_output_object_list():
]
}
- node._node_data = cast(CodeNodeData, node._node_data)
-
# validate
node._transform_result(result, node._node_data.outputs)
diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py.orig b/api/tests/integration_tests/workflow/nodes/test_code.py.orig
new file mode 100644
index 0000000000..fe46ed7658
--- /dev/null
+++ b/api/tests/integration_tests/workflow/nodes/test_code.py.orig
@@ -0,0 +1,390 @@
+import time
+import uuid
+from os import getenv
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
+from core.workflow.enums import WorkflowNodeExecutionStatus
+from core.workflow.graph import Graph
+from core.workflow.node_events import NodeRunResult
+from core.workflow.nodes.code.code_node import CodeNode
+from core.workflow.nodes.node_factory import DifyNodeFactory
+from core.workflow.system_variable import SystemVariable
+from models.enums import UserFrom
+from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
+
+CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
+
+
+def init_code_node(code_config: dict):
+ graph_config = {
+ "edges": [
+ {
+ "id": "start-source-code-target",
+ "source": "start",
+ "target": "code",
+ },
+ ],
+ "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
+ }
+
+ init_params = GraphInitParams(
+ tenant_id="1",
+ app_id="1",
+ workflow_id="1",
+ graph_config=graph_config,
+ user_id="1",
+ user_from=UserFrom.ACCOUNT,
+ invoke_from=InvokeFrom.DEBUGGER,
+ call_depth=0,
+ )
+
+ # construct variable pool
+ variable_pool = VariablePool(
+ system_variables=SystemVariable(user_id="aaa", files=[]),
+ user_inputs={},
+ environment_variables=[],
+ conversation_variables=[],
+ )
+ variable_pool.add(["code", "args1"], 1)
+ variable_pool.add(["code", "args2"], 2)
+
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
+ # Create node factory
+ node_factory = DifyNodeFactory(
+ graph_init_params=init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
+
+ node = CodeNode(
+ id=str(uuid.uuid4()),
+ config=code_config,
+ graph_init_params=init_params,
+ graph_runtime_state=graph_runtime_state,
+ )
+
+ # Initialize node data
+ if "data" in code_config:
+ node.init_node_data(code_config["data"])
+
+ return node
+
+
+@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
+def test_execute_code(setup_code_executor_mock):
+ code = """
+ def main(args1: int, args2: int) -> dict:
+ return {
+ "result": args1 + args2,
+ }
+ """
+ # trim first 4 spaces at the beginning of each line
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "result": {
+ "type": "number",
+ },
+ },
+ "title": "123",
+ "variables": [
+ {
+ "variable": "args1",
+ "value_selector": ["1", "args1"],
+ },
+ {"variable": "args2", "value_selector": ["1", "args2"]},
+ ],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+ node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
+ node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
+
+ # execute node
+ result = node._run()
+ assert isinstance(result, NodeRunResult)
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
+ assert result.outputs is not None
+ assert result.outputs["result"] == 3
+ assert result.error == ""
+
+
+@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
+def test_execute_code_output_validator(setup_code_executor_mock):
+ code = """
+ def main(args1: int, args2: int) -> dict:
+ return {
+ "result": args1 + args2,
+ }
+ """
+ # trim first 4 spaces at the beginning of each line
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "result": {
+ "type": "string",
+ },
+ },
+ "title": "123",
+ "variables": [
+ {
+ "variable": "args1",
+ "value_selector": ["1", "args1"],
+ },
+ {"variable": "args2", "value_selector": ["1", "args2"]},
+ ],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+ node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
+ node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
+
+ # execute node
+ result = node._run()
+ assert isinstance(result, NodeRunResult)
+ assert result.status == WorkflowNodeExecutionStatus.FAILED
+ assert result.error == "Output variable `result` must be a string"
+
+
+def test_execute_code_output_validator_depth():
+ code = """
+ def main(args1: int, args2: int) -> dict:
+ return {
+ "result": {
+ "result": args1 + args2,
+ }
+ }
+ """
+ # trim first 4 spaces at the beginning of each line
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "string_validator": {
+ "type": "string",
+ },
+ "number_validator": {
+ "type": "number",
+ },
+ "number_array_validator": {
+ "type": "array[number]",
+ },
+ "string_array_validator": {
+ "type": "array[string]",
+ },
+ "object_validator": {
+ "type": "object",
+ "children": {
+ "result": {
+ "type": "number",
+ },
+ "depth": {
+ "type": "object",
+ "children": {
+ "depth": {
+ "type": "object",
+ "children": {
+ "depth": {
+ "type": "number",
+ }
+ },
+ }
+ },
+ },
+ },
+ },
+ },
+ "title": "123",
+ "variables": [
+ {
+ "variable": "args1",
+ "value_selector": ["1", "args1"],
+ },
+ {"variable": "args2", "value_selector": ["1", "args2"]},
+ ],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+
+ # construct result
+ result = {
+ "number_validator": 1,
+ "string_validator": "1",
+ "number_array_validator": [1, 2, 3, 3.333],
+ "string_array_validator": ["1", "2", "3"],
+ "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
+ }
+
+ # validate
+ node._transform_result(result, node._node_data.outputs)
+
+ # construct result
+ result = {
+ "number_validator": "1",
+ "string_validator": 1,
+ "number_array_validator": ["1", "2", "3", "3.333"],
+ "string_array_validator": [1, 2, 3],
+ "object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}},
+ }
+
+ # validate
+ with pytest.raises(ValueError):
+ node._transform_result(result, node._node_data.outputs)
+
+ # construct result
+ result = {
+ "number_validator": 1,
+ "string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1",
+ "number_array_validator": [1, 2, 3, 3.333],
+ "string_array_validator": ["1", "2", "3"],
+ "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
+ }
+
+ # validate
+ with pytest.raises(ValueError):
+ node._transform_result(result, node._node_data.outputs)
+
+ # construct result
+ result = {
+ "number_validator": 1,
+ "string_validator": "1",
+ "number_array_validator": [1, 2, 3, 3.333] * 2000,
+ "string_array_validator": ["1", "2", "3"],
+ "object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
+ }
+
+ # validate
+ with pytest.raises(ValueError):
+ node._transform_result(result, node._node_data.outputs)
+
+
+def test_execute_code_output_object_list():
+ code = """
+ def main(args1: int, args2: int) -> dict:
+ return {
+ "result": {
+ "result": args1 + args2,
+ }
+ }
+ """
+ # trim first 4 spaces at the beginning of each line
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "object_list": {
+ "type": "array[object]",
+ },
+ },
+ "title": "123",
+ "variables": [
+ {
+ "variable": "args1",
+ "value_selector": ["1", "args1"],
+ },
+ {"variable": "args2", "value_selector": ["1", "args2"]},
+ ],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+
+ # construct result
+ result = {
+ "object_list": [
+ {
+ "result": 1,
+ },
+ {
+ "result": 2,
+ },
+ {
+ "result": [1, 2, 3],
+ },
+ ]
+ }
+
+ # validate
+ node._transform_result(result, node._node_data.outputs)
+
+ # construct result
+ result = {
+ "object_list": [
+ {
+ "result": 1,
+ },
+ {
+ "result": 2,
+ },
+ {
+ "result": [1, 2, 3],
+ },
+ 1,
+ ]
+ }
+
+ # validate
+ with pytest.raises(ValueError):
+ node._transform_result(result, node._node_data.outputs)
+
+
+def test_execute_code_scientific_notation():
+ code = """
+ def main() -> dict:
+ return {
+ "result": -8.0E-5
+ }
+ """
+ code = "\n".join([line[4:] for line in code.split("\n")])
+
+ code_config = {
+ "id": "code",
+ "data": {
+ "outputs": {
+ "result": {
+ "type": "number",
+ },
+ },
+ "title": "123",
+ "variables": [],
+ "answer": "123",
+ "code_language": "python3",
+ "code": code,
+ },
+ }
+
+ node = init_code_node(code_config)
+ # execute node
+ result = node._run()
+ assert isinstance(result, NodeRunResult)
+ assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
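
The test fixtures above strip a fixed four-space indent with a manual line[4:] slice. As a side note rather than a proposed change to these tests, the standard-library textwrap.dedent achieves the same effect without hard-coding the indent width:

import textwrap

code = """
    def main(args1: int, args2: int) -> dict:
        return {
            "result": args1 + args2,
        }
    """
# dedent removes the common leading whitespace from every non-blank line.
print(textwrap.dedent(code))
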
diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py
new file mode 100644
index 0000000000..bf25968100
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py
@@ -0,0 +1,788 @@
+from unittest.mock import Mock, patch
+
+import pytest
+from faker import Faker
+
+from core.tools.entities.api_entities import ToolProviderApiEntity
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolProviderType
+from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
+from services.tools.tools_transform_service import ToolTransformService
+
+
+class TestToolTransformService:
+ """Integration tests for ToolTransformService using testcontainers."""
+
+ @pytest.fixture
+ def mock_external_service_dependencies(self):
+ """Mock setup for external service dependencies."""
+ with (
+ patch("services.tools.tools_transform_service.dify_config") as mock_dify_config,
+ ):
+ # Setup default mock returns
+ mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
+
+ yield {
+ "dify_config": mock_dify_config,
+ }
+
+ def _create_test_tool_provider(
+ self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
+ ):
+ """
+ Helper method to create a test tool provider for testing.
+
+ Args:
+ db_session_with_containers: Database session from testcontainers infrastructure
+ mock_external_service_dependencies: Mock dependencies
+ provider_type: Type of provider to create
+
+ Returns:
+ Tool provider instance
+ """
+ fake = Faker()
+
+ if provider_type == "api":
+ provider = ApiToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ icon_dark='{"background": "#252525", "content": "🔧"}',
+ tenant_id="test_tenant_id",
+ user_id="test_user_id",
+ credentials={"auth_type": "api_key_header", "api_key": "test_key"},
+ provider_type="api",
+ )
+ elif provider_type == "builtin":
+ provider = BuiltinToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon="🔧",
+ icon_dark="🔧",
+ tenant_id="test_tenant_id",
+ provider="test_provider",
+ credential_type="api_key",
+ credentials={"api_key": "test_key"},
+ )
+ elif provider_type == "workflow":
+ provider = WorkflowToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ icon_dark='{"background": "#252525", "content": "🔧"}',
+ tenant_id="test_tenant_id",
+ user_id="test_user_id",
+ workflow_id="test_workflow_id",
+ )
+ elif provider_type == "mcp":
+ provider = MCPToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
+ tenant_id="test_tenant_id",
+ user_id="test_user_id",
+ server_url="https://mcp.example.com",
+ server_identifier="test_server",
+ tools='[{"name": "test_tool", "description": "Test tool"}]',
+ authed=True,
+ )
+ else:
+ raise ValueError(f"Unknown provider type: {provider_type}")
+
+ from extensions.ext_database import db
+
+ db.session.add(provider)
+ db.session.commit()
+
+ return provider
+
+ def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful plugin icon URL generation.
+
+ This test verifies:
+ - Proper URL construction for plugin icons
+ - Correct tenant_id and filename handling
+ - URL format compliance
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ tenant_id = fake.uuid4()
+ filename = "test_icon.png"
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_plugin_icon_url(tenant_id, filename)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, str)
+ assert "console/api/workspaces/current/plugin/icon" in result
+ assert tenant_id in result
+ assert filename in result
+ assert result.startswith("https://console.example.com")
+
+ # Verify URL structure
+ expected_url = f"https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
+ assert result == expected_url
+
+ def test_get_plugin_icon_url_with_empty_console_url(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test plugin icon URL generation when CONSOLE_API_URL is empty.
+
+ This test verifies:
+ - Fallback to relative URL when CONSOLE_API_URL is None
+ - Proper URL construction with relative path
+ """
+ # Arrange: Setup mock with empty console URL
+ mock_external_service_dependencies["dify_config"].CONSOLE_API_URL = None
+ fake = Faker()
+ tenant_id = fake.uuid4()
+ filename = "test_icon.png"
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_plugin_icon_url(tenant_id, filename)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, str)
+ assert result.startswith("/console/api/workspaces/current/plugin/icon")
+ assert tenant_id in result
+ assert filename in result
+
+ # Verify URL structure
+ expected_url = f"/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
+ assert result == expected_url
+
+ def test_get_tool_provider_icon_url_builtin_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful tool provider icon URL generation for builtin providers.
+
+ This test verifies:
+ - Proper URL construction for builtin tool providers
+ - Correct provider type handling
+ - URL format compliance
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ provider_type = ToolProviderType.BUILT_IN.value
+ provider_name = fake.company()
+ icon = "🔧"
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, str)
+ assert "console/api/workspaces/current/tool-provider/builtin" in result
+ # Note: provider_name may contain spaces that get URL encoded
+ assert provider_name.replace(" ", "%20") in result or provider_name in result
+ assert result.endswith("/icon")
+ assert result.startswith("https://console.example.com")
+
+ # Verify URL structure (accounting for URL encoding)
+ # The actual result will have URL-encoded spaces (%20), so we need to compare accordingly
+ expected_url = (
+ f"https://console.example.com/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon"
+ )
+ # Convert expected URL to match the actual URL encoding
+ expected_encoded = expected_url.replace(" ", "%20")
+ assert result == expected_encoded
+
+ def test_get_tool_provider_icon_url_api_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful tool provider icon URL generation for API providers.
+
+ This test verifies:
+ - Proper icon handling for API tool providers
+ - JSON string parsing for icon data
+ - Fallback icon when parsing fails
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ provider_type = ToolProviderType.API.value
+ provider_name = fake.company()
+ icon = '{"background": "#FF6B6B", "content": "🔧"}'
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, dict)
+ assert result["background"] == "#FF6B6B"
+ assert result["content"] == "🔧"
+
+ def test_get_tool_provider_icon_url_api_invalid_json(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test tool provider icon URL generation for API providers with invalid JSON.
+
+ This test verifies:
+ - Proper fallback when JSON parsing fails
+ - Default icon structure when exception occurs
+ """
+ # Arrange: Setup test data with invalid JSON
+ fake = Faker()
+ provider_type = ToolProviderType.API.value
+ provider_name = fake.company()
+ icon = '{"invalid": json}'
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, dict)
+ assert result["background"] == "#252525"
+ # The fallback icon content is the 😁 emoji
+ assert result["content"] == "😁"
+
+ def test_get_tool_provider_icon_url_workflow_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful tool provider icon URL generation for workflow providers.
+
+ This test verifies:
+ - Proper icon handling for workflow tool providers
+ - Direct icon return for workflow type
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ provider_type = ToolProviderType.WORKFLOW.value
+ provider_name = fake.company()
+ icon = {"background": "#FF6B6B", "content": "🔧"}
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, dict)
+ assert result["background"] == "#FF6B6B"
+ assert result["content"] == "🔧"
+
+ def test_get_tool_provider_icon_url_mcp_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful tool provider icon URL generation for MCP providers.
+
+ This test verifies:
+ - Direct icon return for MCP type
+ - No URL transformation for MCP providers
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ provider_type = ToolProviderType.MCP.value
+ provider_name = fake.company()
+ icon = {"background": "#FF6B6B", "content": "🔧"}
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert isinstance(result, dict)
+ assert result["background"] == "#FF6B6B"
+ assert result["content"] == "🔧"
+
+ def test_get_tool_provider_icon_url_unknown_type(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test tool provider icon URL generation for unknown provider types.
+
+ This test verifies:
+ - Empty string return for unknown provider types
+ - Proper handling of unsupported types
+ """
+ # Arrange: Setup test data with unknown type
+ fake = Faker()
+ provider_type = "unknown_type"
+ provider_name = fake.company()
+ icon = "🔧"
+
+ # Act: Execute the method under test
+ result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
+
+ # Assert: Verify the expected outcomes
+ assert result == ""
+
+ def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful provider repacking with dictionary input.
+
+ This test verifies:
+ - Proper icon URL generation for dictionary providers
+ - Correct provider type handling
+ - Icon transformation for different provider types
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ tenant_id = fake.uuid4()
+ provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"}
+
+ # Act: Execute the method under test
+ ToolTransformService.repack_provider(tenant_id, provider)
+
+ # Assert: Verify the expected outcomes
+ assert "icon" in provider
+ assert isinstance(provider["icon"], str)
+ assert "console/api/workspaces/current/tool-provider/builtin" in provider["icon"]
+ # Note: provider name may contain spaces that get URL encoded
+ assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
+
+ def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful provider repacking with ToolProviderApiEntity input.
+
+ This test verifies:
+ - Proper icon URL generation for entity providers
+ - Plugin icon handling when plugin_id is present
+ - Regular icon handling when plugin_id is not present
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ tenant_id = fake.uuid4()
+
+ # Create provider entity with plugin_id
+ provider = ToolProviderApiEntity(
+ id=fake.uuid4(),
+ author=fake.name(),
+ name=fake.company(),
+ description=I18nObject(en_US=fake.text(max_nb_chars=100)),
+ icon="test_icon.png",
+ icon_dark="test_icon_dark.png",
+ label=I18nObject(en_US=fake.company()),
+ type=ToolProviderType.API,
+ masked_credentials={},
+ is_team_authorization=True,
+ plugin_id="test_plugin_id",
+ tools=[],
+ labels=[],
+ )
+
+ # Act: Execute the method under test
+ ToolTransformService.repack_provider(tenant_id, provider)
+
+ # Assert: Verify the expected outcomes
+ assert provider.icon is not None
+ assert isinstance(provider.icon, str)
+ assert "console/api/workspaces/current/plugin/icon" in provider.icon
+ assert tenant_id in provider.icon
+ assert "test_icon.png" in provider.icon
+
+ # Verify dark icon handling
+ assert provider.icon_dark is not None
+ assert isinstance(provider.icon_dark, str)
+ assert "console/api/workspaces/current/plugin/icon" in provider.icon_dark
+ assert tenant_id in provider.icon_dark
+ assert "test_icon_dark.png" in provider.icon_dark
+
+ def test_repack_provider_entity_no_plugin_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
+
+ This test verifies:
+ - Proper icon URL generation for non-plugin providers
+ - Regular tool provider icon handling
+ - Dark icon handling when present
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ tenant_id = fake.uuid4()
+
+ # Create provider entity without plugin_id
+ provider = ToolProviderApiEntity(
+ id=fake.uuid4(),
+ author=fake.name(),
+ name=fake.company(),
+ description=I18nObject(en_US=fake.text(max_nb_chars=100)),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ icon_dark='{"background": "#252525", "content": "🔧"}',
+ label=I18nObject(en_US=fake.company()),
+ type=ToolProviderType.API,
+ masked_credentials={},
+ is_team_authorization=True,
+ plugin_id=None,
+ tools=[],
+ labels=[],
+ )
+
+ # Act: Execute the method under test
+ ToolTransformService.repack_provider(tenant_id, provider)
+
+ # Assert: Verify the expected outcomes
+ assert provider.icon is not None
+ assert isinstance(provider.icon, dict)
+ assert provider.icon["background"] == "#FF6B6B"
+ assert provider.icon["content"] == "🔧"
+
+ # Verify dark icon handling
+ assert provider.icon_dark is not None
+ assert isinstance(provider.icon_dark, dict)
+ assert provider.icon_dark["background"] == "#252525"
+ assert provider.icon_dark["content"] == "🔧"
+
+ def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test provider repacking with ToolProviderApiEntity input without dark icon.
+
+ This test verifies:
+ - Proper handling when icon_dark is None or empty
+ - No errors when dark icon is not present
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+ tenant_id = fake.uuid4()
+
+ # Create provider entity without dark icon
+ provider = ToolProviderApiEntity(
+ id=fake.uuid4(),
+ author=fake.name(),
+ name=fake.company(),
+ description=I18nObject(en_US=fake.text(max_nb_chars=100)),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ icon_dark=None,
+ label=I18nObject(en_US=fake.company()),
+ type=ToolProviderType.API,
+ masked_credentials={},
+ is_team_authorization=True,
+ plugin_id=None,
+ tools=[],
+ labels=[],
+ )
+
+ # Act: Execute the method under test
+ ToolTransformService.repack_provider(tenant_id, provider)
+
+ # Assert: Verify the expected outcomes
+ assert provider.icon is not None
+ assert isinstance(provider.icon, dict)
+ assert provider.icon["background"] == "#FF6B6B"
+ assert provider.icon["content"] == "🔧"
+
+ # Verify dark icon remains None
+ assert provider.icon_dark is None
+
+ def test_builtin_provider_to_user_provider_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful conversion of builtin provider to user provider.
+
+ This test verifies:
+ - Proper entity creation with all required fields
+ - Credentials schema handling
+ - Team authorization setup
+ - Plugin ID handling
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create mock provider controller
+ mock_controller = Mock()
+ mock_controller.entity.identity.name = fake.company()
+ mock_controller.entity.identity.author = fake.name()
+ mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
+ mock_controller.entity.identity.icon = "🔧"
+ mock_controller.entity.identity.icon_dark = "🔧"
+ mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
+ mock_controller.plugin_id = None
+ mock_controller.plugin_unique_identifier = None
+ mock_controller.tool_labels = ["label1", "label2"]
+ mock_controller.need_credentials = True
+
+ # Mock credentials schema
+ mock_credential = Mock()
+ mock_credential.to_basic_provider_config.return_value.name = "api_key"
+ mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
+
+ # Create mock database provider
+ mock_db_provider = Mock()
+ mock_db_provider.credential_type = "api-key"
+ mock_db_provider.tenant_id = fake.uuid4()
+ mock_db_provider.credentials = {"api_key": "encrypted_key"}
+
+ # Mock encryption
+ with patch("services.tools.tools_transform_service.create_provider_encrypter") as mock_encrypter:
+ mock_encrypter_instance = Mock()
+ mock_encrypter_instance.decrypt.return_value = {"api_key": "decrypted_key"}
+ mock_encrypter_instance.mask_tool_credentials.return_value = {"api_key": ""}
+ mock_encrypter.return_value = (mock_encrypter_instance, None)
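+ # create_provider_encrypter is assumed to return an (encrypter, cache) pair; only the encrypter is used here, so None stands in for the cache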
+
+ # Act: Execute the method under test
+ result = ToolTransformService.builtin_provider_to_user_provider(
+ mock_controller, mock_db_provider, decrypt_credentials=True
+ )
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert result.id == mock_controller.entity.identity.name
+ assert result.author == mock_controller.entity.identity.author
+ assert result.name == mock_controller.entity.identity.name
+ assert result.description == mock_controller.entity.identity.description
+ assert result.icon == mock_controller.entity.identity.icon
+ assert result.icon_dark == mock_controller.entity.identity.icon_dark
+ assert result.label == mock_controller.entity.identity.label
+ assert result.type == ToolProviderType.BUILT_IN
+ assert result.is_team_authorization is True
+ assert result.plugin_id is None
+ assert result.tools == []
+ assert result.labels == ["label1", "label2"]
+ assert result.masked_credentials == {"api_key": ""}
+ assert result.original_credentials == {"api_key": "decrypted_key"}
+
+ def test_builtin_provider_to_user_provider_plugin_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful conversion of builtin provider to user provider with plugin.
+
+ This test verifies:
+ - Plugin ID and unique identifier handling
+ - Proper entity creation for plugin providers
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create mock provider controller with plugin
+ mock_controller = Mock()
+ mock_controller.entity.identity.name = fake.company()
+ mock_controller.entity.identity.author = fake.name()
+ mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
+ mock_controller.entity.identity.icon = "🔧"
+ mock_controller.entity.identity.icon_dark = "🔧"
+ mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
+ mock_controller.plugin_id = "test_plugin_id"
+ mock_controller.plugin_unique_identifier = "test_unique_id"
+ mock_controller.tool_labels = ["label1"]
+ mock_controller.need_credentials = False
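+ # need_credentials=False means no database provider record is required, so None is passed to the conversion below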
+
+ # Mock credentials schema
+ mock_credential = Mock()
+ mock_credential.to_basic_provider_config.return_value.name = "api_key"
+ mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
+
+ # Act: Execute the method under test
+ result = ToolTransformService.builtin_provider_to_user_provider(
+ mock_controller, None, decrypt_credentials=False
+ )
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ # Note: the method checks isinstance(provider_controller, PluginToolProviderController).
+ # Because a plain Mock is passed here, that check fails and plugin_id remains None;
+ # a real PluginToolProviderController would populate the plugin fields.
+ assert result.is_team_authorization is True
+ assert result.allow_delete is False
+
+ def test_builtin_provider_to_user_provider_no_credentials(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test conversion of builtin provider to user provider without credentials.
+
+ This test verifies:
+ - Proper handling when no credentials are needed
+ - Team authorization setup for no-credentials providers
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create mock provider controller
+ mock_controller = Mock()
+ mock_controller.entity.identity.name = fake.company()
+ mock_controller.entity.identity.author = fake.name()
+ mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
+ mock_controller.entity.identity.icon = "🔧"
+ mock_controller.entity.identity.icon_dark = "🔧"
+ mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
+ mock_controller.plugin_id = None
+ mock_controller.plugin_unique_identifier = None
+ mock_controller.tool_labels = []
+ mock_controller.need_credentials = False
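+ # need_credentials=False should leave the provider team-authorized with empty masked credentials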
+
+ # Mock credentials schema
+ mock_credential = Mock()
+ mock_credential.to_basic_provider_config.return_value.name = "api_key"
+ mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
+
+ # Act: Execute the method under test
+ result = ToolTransformService.builtin_provider_to_user_provider(
+ mock_controller, None, decrypt_credentials=False
+ )
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert result.is_team_authorization is True
+ assert result.allow_delete is False
+ assert result.masked_credentials == {}
+
+ def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
+ """
+ Test successful conversion of API provider to controller.
+
+ This test verifies:
+ - Proper controller creation from database provider
+ - Auth type handling for different credential types
+ - Backward compatibility for auth types
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create API tool provider with api_key_header auth
+ provider = ApiToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ tenant_id=fake.uuid4(),
+ user_id=fake.uuid4(),
+ credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
+ schema="{}",
+ schema_type_str="openapi",
+ tools_str="[]",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(provider)
+ db.session.commit()
+
+ # Act: Execute the method under test
+ result = ToolTransformService.api_provider_to_controller(provider)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert hasattr(result, "from_db")
+ # Additional assertions would depend on the actual controller implementation
+
+ def test_api_provider_to_controller_api_key_query(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test conversion of API provider to controller with api_key_query auth type.
+
+ This test verifies:
+ - Proper auth type handling for query parameter authentication
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create API tool provider with api_key_query auth
+ provider = ApiToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ tenant_id=fake.uuid4(),
+ user_id=fake.uuid4(),
+ credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
+ schema="{}",
+ schema_type_str="openapi",
+ tools_str="[]",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(provider)
+ db.session.commit()
+
+ # Act: Execute the method under test
+ result = ToolTransformService.api_provider_to_controller(provider)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert hasattr(result, "from_db")
+
+ def test_api_provider_to_controller_backward_compatibility(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test conversion of API provider to controller with backward compatibility auth types.
+
+ This test verifies:
+ - Proper handling of legacy auth type values
+ - Backward compatibility for api_key and api_key_header
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create API tool provider with legacy auth type
+ provider = ApiToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ tenant_id=fake.uuid4(),
+ user_id=fake.uuid4(),
+ credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
+ schema="{}",
+ schema_type_str="openapi",
+ tools_str="[]",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(provider)
+ db.session.commit()
+
+ # Act: Execute the method under test
+ result = ToolTransformService.api_provider_to_controller(provider)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert hasattr(result, "from_db")
+
+ def test_workflow_provider_to_controller_success(
+ self, db_session_with_containers, mock_external_service_dependencies
+ ):
+ """
+ Test successful conversion of workflow provider to controller.
+
+ This test verifies:
+ - Proper controller creation from workflow provider
+ - Workflow-specific controller handling
+ """
+ # Arrange: Setup test data
+ fake = Faker()
+
+ # Create workflow tool provider
+ provider = WorkflowToolProvider(
+ name=fake.company(),
+ description=fake.text(max_nb_chars=100),
+ icon='{"background": "#FF6B6B", "content": "🔧"}',
+ tenant_id=fake.uuid4(),
+ user_id=fake.uuid4(),
+ app_id=fake.uuid4(),
+ label="Test Workflow",
+ version="1.0.0",
+ parameter_configuration="[]",
+ )
+
+ from extensions.ext_database import db
+
+ db.session.add(provider)
+ db.session.commit()
+
+ # Mock the WorkflowToolProviderController.from_db method to avoid app dependency
+ with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:
+ mock_controller = Mock()
+ mock_from_db.return_value = mock_controller
+
+ # Act: Execute the method under test
+ result = ToolTransformService.workflow_provider_to_controller(provider)
+
+ # Assert: Verify the expected outcomes
+ assert result is not None
+ assert result == mock_controller
+ mock_from_db.assert_called_once_with(provider)
diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts
index b4c4f1540d..b579f22d4b 100644
--- a/web/__tests__/check-i18n.test.ts
+++ b/web/__tests__/check-i18n.test.ts
@@ -621,7 +621,7 @@ export default translation
&& !trimmed.startsWith('//'))
break
}
- else {
+ else {
break
}
diff --git a/web/__tests__/description-validation.test.tsx b/web/__tests__/description-validation.test.tsx
index 85263b035f..a78a4e632e 100644
--- a/web/__tests__/description-validation.test.tsx
+++ b/web/__tests__/description-validation.test.tsx
@@ -60,7 +60,7 @@ describe('Description Validation Logic', () => {
try {
validateDescriptionLength(invalidDescription)
}
- catch (error) {
+ catch (error) {
expect((error as Error).message).toBe(expectedErrorMessage)
}
})
@@ -86,7 +86,7 @@ describe('Description Validation Logic', () => {
expect(() => validateDescriptionLength(testDescription)).not.toThrow()
expect(validateDescriptionLength(testDescription)).toBe(testDescription)
}
- else {
+ else {
expect(() => validateDescriptionLength(testDescription)).toThrow(
'Description cannot exceed 400 characters.',
)
diff --git a/web/__tests__/document-list-sorting.test.tsx b/web/__tests__/document-list-sorting.test.tsx
index 1510dbec23..77c0bb60cf 100644
--- a/web/__tests__/document-list-sorting.test.tsx
+++ b/web/__tests__/document-list-sorting.test.tsx
@@ -39,7 +39,7 @@ describe('Document List Sorting', () => {
const result = aValue.localeCompare(bValue)
return order === 'asc' ? result : -result
}
- else {
+ else {
const result = aValue - bValue
return order === 'asc' ? result : -result
}
diff --git a/web/__tests__/plugin-tool-workflow-error.test.tsx b/web/__tests__/plugin-tool-workflow-error.test.tsx
index 370052bc80..87bda8fa13 100644
--- a/web/__tests__/plugin-tool-workflow-error.test.tsx
+++ b/web/__tests__/plugin-tool-workflow-error.test.tsx
@@ -196,7 +196,7 @@ describe('Plugin Tool Workflow Integration', () => {
const _pluginId = (tool.uniqueIdentifier as any).split(':')[0]
}).toThrow()
}
- else {
+ else {
// Valid tools should work fine
expect(() => {
const _pluginId = tool.uniqueIdentifier.split(':')[0]
diff --git a/web/__tests__/real-browser-flicker.test.tsx b/web/__tests__/real-browser-flicker.test.tsx
index cf3abd5f80..52bdf4777f 100644
--- a/web/__tests__/real-browser-flicker.test.tsx
+++ b/web/__tests__/real-browser-flicker.test.tsx
@@ -252,7 +252,7 @@ describe('Real Browser Environment Dark Mode Flicker Test', () => {
if (hasStyleChange)
console.log('⚠️ Style changes detected - this causes visible flicker')
- else
+ else
console.log('✅ No style changes detected')
expect(timingData.length).toBeGreaterThan(1)
diff --git a/web/__tests__/workflow-parallel-limit.test.tsx b/web/__tests__/workflow-parallel-limit.test.tsx
index 0843122ab4..64e9d328f0 100644
--- a/web/__tests__/workflow-parallel-limit.test.tsx
+++ b/web/__tests__/workflow-parallel-limit.test.tsx
@@ -15,7 +15,7 @@ const originalEnv = process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT
function setupEnvironment(value?: string) {
if (value)
process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = value
- else
+ else
delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT
// Clear module cache to force re-evaluation
@@ -25,7 +25,7 @@ function setupEnvironment(value?: string) {
function restoreEnvironment() {
if (originalEnv)
process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT = originalEnv
- else
+ else
delete process.env.NEXT_PUBLIC_MAX_PARALLEL_LIMIT
jest.resetModules()
diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx
index a3281be8eb..b1e915b2bf 100644
--- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx
+++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/__tests__/svg-attribute-error-reproduction.spec.tsx
@@ -47,7 +47,7 @@ describe('SVG Attribute Error Reproduction', () => {
console.log(` ${index + 1}. ${error.substring(0, 100)}...`)
})
}
- else {
+ else {
console.log('No inkscape errors found in this render')
}
@@ -150,7 +150,7 @@ describe('SVG Attribute Error Reproduction', () => {
if (problematicKeys.length > 0)
console.log(`🚨 PROBLEM: Still found problematic attributes: ${problematicKeys.join(', ')}`)
- else
+ else
console.log('✅ No problematic attributes found after normalization')
})
})
diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx
index 0408d2ee34..5890c2ea92 100644
--- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx
+++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx
@@ -106,7 +106,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
onClick={() => {
if (hoverArea === 'right' && !onAvatarError)
setIsShowDeleteConfirm(true)
- else
+ else
setIsShowAvatarPicker(true)
}}
onMouseMove={(e) => {
diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx
index 00357d6c27..77a965c03e 100644
--- a/web/app/components/app-sidebar/basic.tsx
+++ b/web/app/components/app-sidebar/basic.tsx
@@ -45,8 +45,8 @@ const ICON_MAP = {
,
dataset: