Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine

-LAN- 2025-09-02 11:52:25 +08:00
commit 0b0dc63f29
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF
272 changed files with 3736 additions and 1072 deletions

View File

@ -89,7 +89,9 @@ jobs:
- name: Web style check
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: pnpm run lint
run: |
pnpm run lint
pnpm run eslint
docker-compose-template:
name: Docker Compose Template

View File

@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
raise BadRequest("Invalid oauth_provider_app")
if not request.headers.get("Authorization"):
raise BadRequest("Authorization is required")
authorization_header = request.headers.get("Authorization")
if not authorization_header:
raise BadRequest("Authorization header is required")
parts = authorization_header.split(" ")
parts = authorization_header.strip().split(" ")
if len(parts) != 2:
raise BadRequest("Invalid Authorization header format")
token_type = parts[0]
if token_type != "Bearer":
token_type = parts[0].strip()
if token_type.lower() != "bearer":
raise BadRequest("token_type is invalid")
access_token = parts[1]
access_token = parts[1].strip()
if not access_token:
raise BadRequest("access_token is required")
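The hunk above tightens header parsing: it strips whitespace, requires exactly two parts, compares the scheme case-insensitively, and rejects empty tokens. A minimal standalone sketch of the same checks; the helper name parse_bearer_token is an illustrative assumption, and BadRequest is assumed to be werkzeug's:

from werkzeug.exceptions import BadRequest

def parse_bearer_token(authorization_header: str | None) -> str:
    # mirrors the validation added in the hunk above
    if not authorization_header:
        raise BadRequest("Authorization header is required")
    parts = authorization_header.strip().split(" ")
    if len(parts) != 2:
        raise BadRequest("Invalid Authorization header format")
    token_type, access_token = parts[0].strip(), parts[1].strip()
    if token_type.lower() != "bearer":
        raise BadRequest("token_type is invalid")
    if not access_token:
        raise BadRequest("access_token is required")
    return access_token

For example, parse_bearer_token("bearer abc123") returns "abc123", while a missing header or a non-Bearer scheme becomes a 400.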
@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
parser.add_argument("refresh_token", type=str, required=False, location="json")
parsed_args = parser.parse_args()
grant_type = OAuthGrantType(parsed_args["grant_type"])
try:
grant_type = OAuthGrantType(parsed_args["grant_type"])
except ValueError:
raise BadRequest("invalid grant_type")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not parsed_args["code"]:
@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
"refresh_token": refresh_token,
}
)
else:
raise BadRequest("invalid grant_type")
class OAuthServerUserAccountApi(Resource):
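Assuming OAuthGrantType is a value-based Enum (as the new try/except implies), constructing it from an unknown string raises ValueError; catching that at the parse site is what allows the dangling else branch further down to be removed. A small hedged sketch with a stand-in enum:

from enum import Enum
from werkzeug.exceptions import BadRequest

class GrantType(Enum):  # stand-in for OAuthGrantType
    AUTHORIZATION_CODE = "authorization_code"
    REFRESH_TOKEN = "refresh_token"

def parse_grant_type(raw: str) -> GrantType:
    try:
        return GrantType(raw)  # Enum lookup raises ValueError for unmapped values
    except ValueError:
        raise BadRequest("invalid grant_type")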

View File

@ -354,9 +354,6 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor:
raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:

View File

@ -1,8 +1,12 @@
from base64 import b64encode
from collections.abc import Callable
from functools import wraps
from hashlib import sha1
from hmac import new as hmac_new
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
from flask import abort, request
from configs import dify_config
@ -10,9 +14,9 @@ from extensions.ext_database import db
from models.model import EndUser
def billing_inner_api_only(view):
def billing_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -26,9 +30,9 @@ def billing_inner_api_only(view):
return decorated
def enterprise_inner_api_only(view):
def enterprise_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.INNER_API:
abort(404)
@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
return decorated
def plugin_inner_api_only(view):
def plugin_inner_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args, **kwargs):
def decorated(*args: P.args, **kwargs: P.kwargs):
if not dify_config.PLUGIN_DAEMON_KEY:
abort(404)
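The decorator typing above relies on PEP 612: ParamSpec captures the wrapped view's parameters, so type checkers see the decorated function with its original signature instead of a bare (*args, **kwargs). A minimal sketch under the same assumptions (the decorator name and the guard comment are illustrative):

from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def require_inner_api(view: Callable[P, R]) -> Callable[P, R]:
    # illustrative guard; the real decorators abort(404) when a config flag is off
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
        return view(*args, **kwargs)
    return decorated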

View File

@ -1,7 +1,7 @@
import time
from collections.abc import Callable
from datetime import timedelta
from enum import Enum
from enum import StrEnum, auto
from functools import wraps
from typing import Optional
@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
class WhereisUserArg(Enum):
class WhereisUserArg(StrEnum):
"""
Enum for whereis_user_arg.
"""
QUERY = "query"
JSON = "json"
FORM = "form"
QUERY = auto()
JSON = auto()
FORM = auto()
class FetchUserArg(BaseModel):
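In the WhereisUserArg change above, auto() on a StrEnum (Python 3.11+) assigns each member its lower-cased name, so the string values "query", "json" and "form" are preserved while the literals are dropped. A quick illustrative check:

from enum import StrEnum, auto

class WhereisUserArg(StrEnum):
    QUERY = auto()  # "query"
    JSON = auto()   # "json"
    FORM = auto()   # "form"

assert WhereisUserArg.QUERY == "query"  # StrEnum members compare equal to str
assert WhereisUserArg("form") is WhereisUserArg.FORM

A later hunk makes the analogous change to PublishFrom with IntEnum, where auto() numbers members from 1 and so keeps the original 1/2 values.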

View File

@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
"""
Save agent thought
"""
agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
agent_thought = db.session.scalar(stmt)
if not agent_thought:
raise ValueError("agent thought not found")
@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
stmt = select(MessageFile).where(MessageFile.message_id == message.id)
files = db.session.scalars(stmt).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:
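The query change above is the pattern applied throughout this commit: legacy db.session.query(Model).where(...).first()/.all() calls are rewritten as SQLAlchemy 2.0-style select() statements executed with db.session.scalar()/scalars(). A self-contained sketch of the two idioms; the model, engine and values are illustrative, not Dify's:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class MessageFile(Base):  # stand-in for the real model
    __tablename__ = "message_files"
    id: Mapped[int] = mapped_column(primary_key=True)
    message_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # legacy style being removed:
    #   files = session.query(MessageFile).where(MessageFile.message_id == "m1").all()
    # 2.0 style being introduced: build a Select, then execute it
    stmt = select(MessageFile).where(MessageFile.message_id == "m1")
    files = session.scalars(stmt).all()  # all matching rows, like .all()
    first = session.scalar(stmt)         # first row or None, like .first()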

View File

@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
db.session.refresh(user)
db.session.close()
# return response or stream generator
response = self._handle_advanced_chat_response(
application_generate_entity=application_generate_entity,
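The generator hunk above starts the worker thread, refreshes the ORM rows so their attributes are loaded, and then closes the request session so no pooled connection is held while the (potentially long) generation runs. A hedged sketch of the same ordering; db, workflow, message, user and worker stand in for the real objects:

import threading

def run_generation(db, workflow, message, user, worker):
    worker_thread = threading.Thread(target=worker, daemon=True)
    worker_thread.start()
    # reload attribute values while the session is still open...
    db.session.refresh(workflow)
    db.session.refresh(message)
    db.session.refresh(user)
    # ...then give the connection back; the now-detached rows keep the values
    # already loaded, but any further lazy-load would need a new session.
    db.session.close()
    return worker_thread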

View File

@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
with Session(db.engine, expire_on_commit=False) as session:
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
if not app_record:
raise ValueError("App not found")
@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)
# init graph

View File

@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Conversation, EndUser, Message, MessageFile
@ -886,10 +885,6 @@ class AdvancedChatAppGenerateTaskPipeline:
self._task_state.metadata.usage = usage
else:
self._task_state.metadata.usage = LLMUsage.empty_usage()
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
"""

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
app_stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(app_stmt)
if not app_record:
raise ValueError("App not found")
@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
conversation_result = db.session.scalar(conversation_stmt)
if conversation_result is None:
raise ValueError("Conversation not found")
message_result = db.session.query(Message).where(Message.id == message.id).first()
msg_stmt = select(Message).where(Message.id == message.id)
message_result = db.session.scalar(msg_stmt)
if message_result is None:
raise ValueError("Message not found")
db.session.close()

View File

@ -1,7 +1,7 @@
import queue
import time
from abc import abstractmethod
from enum import Enum
from enum import IntEnum, auto
from typing import Any, Optional
from sqlalchemy.orm import DeclarativeMeta
@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
from extensions.ext_redis import redis_client
class PublishFrom(Enum):
APPLICATION_MANAGER = 1
TASK_PIPELINE = 2
class PublishFrom(IntEnum):
APPLICATION_MANAGER = auto()
TASK_PIPELINE = auto()
class AppQueueManager:

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.chat.app_config_manager import ChatAppConfig
@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app
from pydantic import ValidationError
from sqlalchemy import select
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
:param invoke_from: invoke from source
:param stream: is stream
"""
message = (
db.session.query(Message)
.where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
.first()
stmt = select(Message).where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
Message.from_account_id == (user.id if isinstance(user, Account) else None),
)
message = db.session.scalar(stmt)
if not message:
raise MessageNotExistsError()

View File

@ -1,6 +1,8 @@
import logging
from typing import cast
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.apps.completion.app_config_manager import CompletionAppConfig
@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
"""
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
app_record = db.session.query(App).where(App.id == app_config.app_id).first()
stmt = select(App).where(App.id == app_config.app_id)
app_record = db.session.scalar(stmt)
if not app_record:
raise ValueError("App not found")

View File

@ -3,6 +3,9 @@ import logging
from collections.abc import Generator
from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
stmt = select(AppModelConfig).where(
AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
)
app_model_config = db.session.scalar(stmt)
if not app_model_config:
raise AppModelConfigBrokenError()
@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
with Session(db.engine, expire_on_commit=False) as session:
conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
message = db.session.query(Message).where(Message.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message = session.scalar(select(Message).where(Message.id == message_id))
if message is None:
raise MessageNotExistsError("Message not exists")
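Here (and in the advanced-chat runner earlier) the lookup runs inside a short-lived Session(db.engine, expire_on_commit=False) context manager: the session takes its own connection, returns it on exit, and expire_on_commit=False keeps loaded attributes readable after the block even if the session commits. A compact sketch with an illustrative model:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Conversation(Base):  # stand-in model
    __tablename__ = "conversations"
    id: Mapped[str] = mapped_column(String(36), primary_key=True)

engine = create_engine("sqlite://")  # stands in for db.engine
Base.metadata.create_all(engine)

def get_conversation(conversation_id: str) -> Conversation | None:
    with Session(engine, expire_on_commit=False) as session:
        return session.scalar(select(Conversation).where(Conversation.id == conversation_id))

The returned instance is detached once the with-block exits, which is fine as long as callers only read columns that were already loaded.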

View File

@ -1,6 +1,8 @@
import logging
from typing import Optional
from sqlalchemy import select
from core.app.entities.app_invoke_entities import InvokeFrom
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
@ -25,9 +27,8 @@ class AnnotationReplyFeature:
:param invoke_from: invoke from
:return:
"""
annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
annotation_setting = db.session.scalar(stmt)
if not annotation_setting:
return None

View File

@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:param event: agent thought event
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: Optional[MessageAgentThought] = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
return AgentThoughtStreamResponse(

View File

@ -3,6 +3,8 @@ from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import (
@ -84,7 +86,8 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation = db.session.scalar(stmt)
if not conversation:
return
@ -143,7 +146,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
# get tool file id
@ -183,7 +187,8 @@ class MessageCycleManager:
:param message_id: message id
:return:
"""
message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
with Session(db.engine, expire_on_commit=False) as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE
return MessageStreamResponse(

View File

@ -1,6 +1,8 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
dataset_document = db.session.scalar(dataset_document_stmt)
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler:
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)

View File

@ -1,5 +1,7 @@
from typing import Optional
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
raise ValueError(f"config is required, config: {self.config}")
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
)
api_based_extension = db.session.scalar(stmt)
if not api_based_extension:
raise ValueError(

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa
def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:
@ -11,6 +11,10 @@ def obfuscated_token(token: str):
return token[:6] + "*" * 12 + token[-2:]
def full_mask_token(token_length=20):
return "*" * token_length
def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
import requests
import httpx
from yarl import URL
from configs import dify_config
@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response = httpx.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
@ -50,5 +50,5 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()
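The marketplace helpers above swap requests for httpx; for these simple synchronous POSTs the call shape is identical. One behavioural difference worth noting is that httpx applies a default timeout (5 seconds) where requests would wait indefinitely. A hedged sketch with a placeholder URL and payload:

import httpx

def post_json(url: str, payload: dict) -> dict:
    response = httpx.post(url, json=payload, timeout=10.0)  # explicit timeout; httpx defaults to 5s
    response.raise_for_status()
    return response.json()

# usage (placeholder URL):
# plugins = post_json("https://marketplace.example.com/api/v1/plugins/batch",
#                     {"plugin_ids": ["a", "b"]})["data"]["plugins"]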

View File

@ -8,6 +8,7 @@ import uuid
from typing import Any, Optional, cast
from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
@ -56,13 +57,11 @@ class IndexingRunner:
if not dataset:
raise ValueError("no dataset found")
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
stmt = select(DatasetProcessRule).where(
DatasetProcessRule.id == dataset_document.dataset_process_rule_id
)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
index_type = dataset_document.doc_form
@ -123,11 +122,8 @@ class IndexingRunner:
db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
processing_rule = db.session.scalar(stmt)
if not processing_rule:
raise ValueError("no process rule found")
@ -208,7 +204,6 @@ class IndexingRunner:
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# build index
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
@ -310,7 +305,8 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
image_file = db.session.scalar(stmt)
if image_file is None:
continue
try:
@ -339,10 +335,8 @@ class IndexingRunner:
if dataset_document.data_source_type == "upload_file":
if not data_source_info or "upload_file_id" not in data_source_info:
raise ValueError("no upload file found")
file_detail = (
db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
file_detail = db.session.scalars(stmt).one_or_none()
if file_detail:
extract_setting = ExtractSetting(

View File

@ -110,9 +110,9 @@ class TokenBufferMemory:
else:
message_limit = 500
stmt = stmt.limit(message_limit)
msg_limit_stmt = stmt.limit(message_limit)
messages = db.session.scalars(stmt).all()
messages = db.session.scalars(msg_limit_stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message

View File

@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)
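The deleted self.model_type_instance = cast(...) lines throughout ModelInstance were redundant: the preceding isinstance check already narrows the attribute's type for current type checkers, so the reassignment added nothing at runtime or for typing. A small illustrative sketch (all names are stand-ins):

class BaseModelType:
    pass

class LargeLanguageModel(BaseModelType):
    def invoke(self, prompt: str) -> str:
        return f"echo: {prompt}"

class ModelInstance:
    def __init__(self, model_type_instance: BaseModelType):
        self.model_type_instance = model_type_instance

    def invoke_llm(self, prompt: str) -> str:
        if not isinstance(self.model_type_instance, LargeLanguageModel):
            raise Exception("Model type instance is not LargeLanguageModel")
        # no cast needed: inside this branch the attribute is narrowed
        # to LargeLanguageModel by the isinstance guard above
        return self.model_type_instance.invoke(prompt)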

View File

@ -1,6 +1,7 @@
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
from core.helper.encrypter import decrypt_token
@ -87,10 +88,9 @@ class ApiModeration(Moderation):
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
.where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
stmt = select(APIBasedExtension).where(
APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
)
extension = db.session.scalar(stmt)
return extension

View File

@ -5,6 +5,7 @@ from typing import Optional
from urllib.parse import urljoin
from opentelemetry.trace import Link, Status, StatusCode
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance):
app_id = trace_info.metadata.get("app_id")
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.ops.entities.config_entity import BaseTracingConfig
@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
app = session.query(App).where(App.id == app_id).first()
app_stmt = select(App).where(App.id == app_id)
app = session.scalar(app_stmt)
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
service_account = session.query(Account).where(Account.id == app.created_by).first()
account_stmt = select(Account).where(Account.id == app.created_by)
service_account = session.scalar(account_stmt)
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")

View File

@ -228,9 +228,9 @@ class OpsTraceManager:
if not trace_config_data:
return None
# decrypt_token
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("App not found")
@ -297,20 +297,19 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
message_data = db.session.query(Message).where(Message.id == message_id).first()
message_stmt = select(Message).where(Message.id == message_id)
message_data = db.session.scalar(message_stmt)
if not message_data:
return None
conversation_id = message_data.conversation_id
conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
conversation_data = db.session.scalar(conversation_stmt)
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
.where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
app_model_config = db.session.scalar(config_stmt)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
app_model_config = conversation_data.override_model_configs

View File

@ -1,6 +1,8 @@
from collections.abc import Generator, Mapping
from typing import Optional, Union
from sqlalchemy import select
from controllers.service_api.wraps import create_or_update_end_user_for_user_id
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"""
get the user by user id
"""
user = db.session.query(EndUser).where(EndUser.id == user_id).first()
stmt = select(EndUser).where(EndUser.id == user_id)
user = db.session.scalar(stmt)
if not user:
user = db.session.query(Account).where(Account.id == user_id).first()
stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(stmt)
if not user:
raise ValueError("user not found")

View File

@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

View File

@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):
@ -276,15 +276,11 @@ class ProviderManager:
:param model_type: model type
:return:
"""
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
# If it does not exist, get the first available provider model from get_configurations
# and update the TenantDefaultModel record
@ -367,16 +363,11 @@ class ProviderManager:
model_names = [model.model for model in available_models]
if model not in model_names:
raise ValueError(f"Model {model} does not exist.")
# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
.where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
.first()
stmt = select(TenantDefaultModel).where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
default_model = db.session.scalar(stmt)
# create or update TenantDefaultModel record
if default_model:
@ -598,16 +589,13 @@ class ProviderManager:
provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
except IntegrityError:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
.where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
.first()
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
continue

View File

@ -3,6 +3,7 @@ from typing import Any, Optional
import orjson
from pydantic import BaseModel
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
return sorted_chunk_indices[:k]
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
)
document_segment = db.session.scalar(stmt)
if document_segment:
document_segment.keywords = keywords
db.session.add(document_segment)

View File

@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
from sqlalchemy import select
from sqlalchemy.orm import Session, load_only
from configs import dify_config
@ -24,7 +25,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -127,7 +128,8 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
metadata_condition = (
@ -316,10 +318,8 @@ class RetrievalService:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# Handle parent-child documents
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
child_chunk = db.session.scalar(child_chunk_stmt)
if not child_chunk:
continue
@ -378,17 +378,13 @@ class RetrievalService:
index_node_id = document.metadata.get("doc_id")
if not index_node_id:
continue
segment = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
segment = db.session.scalar(document_segment_stmt)
if not segment:
continue

View File

@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(
@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
response = self._client.query_collection_data(request)
documents = []
for match in response.body.matches.match:
if match.score > score_threshold:
if match.score >= score_threshold:
metadata = json.loads(match.metadata.get("metadata_"))
metadata["score"] = match.score
doc = Document(

View File

@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
documents = []
for record in cur:
id, vector, score, page_content, metadata = record
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=page_content,

View File

@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
if meta is not None:
meta = json.loads(meta)
score = row.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
docs.append(doc)
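The >-to->= change repeated across these vector-store adapters makes the score threshold inclusive: a hit scoring exactly score_threshold (including the default 0.0) is now kept instead of silently dropped. A tiny illustrative filter showing the boundary difference:

scores = [0.0, 0.35, 0.5, 0.8]
score_threshold = 0.5

kept_exclusive = [s for s in scores if s > score_threshold]   # old behaviour: [0.8]
kept_inclusive = [s for s in scores if s >= score_threshold]  # new behaviour: [0.5, 0.8]
assert kept_inclusive == [0.5, 0.8]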

View File

@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
distance = distances[index]
metadata = dict(metadatas[index])
score = 1 - distance
if score > score_threshold:
if score >= score_threshold:
metadata["score"] = score
doc = Document(
page_content=documents[index],

View File

@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 2)
top_k = kwargs.get("top_k", 4)
try:
CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
search_iter = self._scope.search(

View File

@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
docs = []
for doc, score in docs_and_scores:
score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):
metadata["score"] = hit["_score"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if hit["_score"] > score_threshold:
if hit["_score"] >= score_threshold:
doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -261,7 +261,7 @@ class OracleVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
conn.close()
return docs

View File

@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
score = 1 - dis
metadata["score"] = score
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if score > score_threshold:
if score >= score_threshold:
doc = Document(page_content=record.text, metadata=metadata)
docs.append(doc)
return docs

View File

@ -195,7 +195,7 @@ class PGVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
metadata, text, distance = record
score = 1 - distance
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
from flask import current_app
@ -18,6 +18,7 @@ from qdrant_client.http.models import (
TokenizerType,
)
from qdrant_client.local.qdrant_local import QdrantLocal
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -369,7 +370,7 @@ class QdrantVector(BaseVector):
continue
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
@ -426,7 +427,6 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod
@ -446,11 +446,8 @@ class QdrantVector(BaseVector):
class QdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
.where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
dataset_collection_binding = db.session.scalars(stmt).one_or_none()
if dataset_collection_binding:
collection_name = dataset_collection_binding.collection_name
else:

View File

@ -233,7 +233,7 @@ class RelytVector(BaseVector):
docs = []
for document, score in results:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if 1 - score > score_threshold:
if 1 - score >= score_threshold:
docs.append(document)
return docs

View File

@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
)
documents = []
for search_hit in search_response.search_hits:
if search_hit.score > score_threshold:
if search_hit.score >= score_threshold:
ots_column_map = {}
for col in search_hit.row[1]:
ots_column_map[col[0]] = col[1]

View File

@ -39,6 +39,9 @@ class TencentConfig(BaseModel):
return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}
bm25 = BM25Encoder.default("zh")
class TencentVector(BaseVector):
field_id: str = "id"
field_vector: str = "vector"
@ -53,7 +56,6 @@ class TencentVector(BaseVector):
self._dimension = 1024
self._init_database()
self._load_collection()
self._bm25 = BM25Encoder.default("zh")
def _load_collection(self):
"""
@ -186,7 +188,7 @@ class TencentVector(BaseVector):
metadata=metadata,
)
if self._enable_hybrid_search:
doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i])
docs.append(doc)
self._client.upsert(
database_name=self._client_config.database,
@ -264,7 +266,7 @@ class TencentVector(BaseVector):
match=[
KeywordSearch(
field_name="sparse_vector",
data=self._bm25.encode_queries(query),
data=bm25.encode_queries(query),
),
],
rerank=WeightedRerank(
@ -291,7 +293,7 @@ class TencentVector(BaseVector):
score = 1 - result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score > score_threshold:
if score >= score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
docs.append(doc)

View File

@ -20,6 +20,7 @@ from qdrant_client.http.models import (
)
from qdrant_client.local.qdrant_local import QdrantLocal
from requests.auth import HTTPDigestAuth
from sqlalchemy import select
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -351,7 +352,7 @@ class TidbOnQdrantVector(BaseVector):
metadata = result.payload.get(Field.METADATA_KEY.value) or {}
# duplicate check score threshold
score_threshold = kwargs.get("score_threshold") or 0.0
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(
page_content=result.payload.get(Field.CONTENT_KEY.value, ""),
@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
.where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
tidb_auth_binding = db.session.scalars(stmt).one_or_none()
if tidb_auth_binding:
TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"

View File

@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
score = record.score
if metadata is not None and text is not None:
metadata["score"] = score
if score > score_threshold:
if score >= score_threshold:
docs.append(Document(page_content=text, metadata=metadata))
return docs

View File

@ -3,6 +3,8 @@ import time
from abc import ABC, abstractmethod
from typing import Any, Optional
from sqlalchemy import select
from configs import dify_config
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -45,11 +47,10 @@ class Vector:
vector_type = self._dataset.index_struct_dict["type"]
else:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
stmt = select(Whitelist).where(
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
)
whitelist = db.session.scalars(stmt).one_or_none()
if whitelist:
vector_type = VectorType.TIDB_ON_QDRANT

View File

@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
if metadata is not None:
metadata = json.loads(metadata)
if result.score > score_threshold:
if result.score >= score_threshold:
metadata["score"] = result.score
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)

View File

@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
for doc, score in docs_and_scores:
score_threshold = float(kwargs.get("score_threshold") or 0.0)
# check score threshold
if score > score_threshold:
if score >= score_threshold:
if doc.metadata is not None:
doc.metadata["score"] = score
docs.append(doc)

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from typing import Any, Optional
from sqlalchemy import func
from sqlalchemy import func, select
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
@ -41,9 +41,8 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
document_segments = db.session.scalars(stmt).all()
output = {}
for document_segment in document_segments:
@ -228,10 +227,9 @@ class DatasetDocumentStore:
return data
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
)
document_segment = db.session.scalar(stmt)
return document_segment

View File

@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]

View File

@ -4,6 +4,7 @@ import operator
from typing import Any, Optional, cast
import requests
from sqlalchemy import select
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor
@ -367,22 +368,17 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
.first()
stmt = select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
data_source_binding = db.session.scalar(stmt)
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

View File

@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

View File

@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -130,13 +130,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if delete_child_chunks:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
).delete(synchronize_session=False)
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
# Use existing compound index: (tenant_id, dataset_id, ...)
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
).delete(synchronize_session=False)
db.session.commit()
def retrieve(
@ -162,7 +165,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
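The cleanup above adds synchronize_session=False to the bulk child-chunk deletes and scopes the second delete by tenant_id so it can use the existing compound index. With synchronize_session=False the DELETE is emitted without trying to reconcile objects already loaded in the session, which is the cheap and safe choice when a commit follows immediately. A hedged sketch with a stand-in model:

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class ChildChunk(Base):  # stand-in for the real model
    __tablename__ = "child_chunks"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36))
    dataset_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.query(ChildChunk).where(
        ChildChunk.tenant_id == "t1", ChildChunk.dataset_id == "d1"
    ).delete(synchronize_session=False)
    session.commit()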

View File

@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

View File

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session
@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -135,7 +135,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
@ -240,15 +241,12 @@ class DatasetRetrieval:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,
@ -327,7 +325,8 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
if dataset:
results = []
if dataset.provider == "external":
@ -514,22 +513,18 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
.where(DatasetDocument.id == document.metadata["document_id"])
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == document.metadata["document_id"]
)
dataset_document = db.session.scalar(dataset_document_stmt)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
.where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
.first()
child_chunk_stmt = select(ChildChunk).where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
segment = (
db.session.query(DocumentSegment)
@ -600,7 +595,8 @@ class DatasetRetrieval:
):
with flask_app.app_context():
with Session(db.engine) as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
dataset = session.scalar(dataset_stmt)
if not dataset:
return []
@ -647,7 +643,7 @@ class DatasetRetrieval:
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,
@ -685,7 +681,8 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(dataset_stmt)
# pass if dataset is not available
if not dataset:
@ -743,7 +740,7 @@ class DatasetRetrieval:
tool = DatasetMultiRetrieverTool.from_dataset(
dataset_ids=[dataset.id for dataset in available_datasets],
tenant_id=tenant_id,
top_k=retrieve_config.top_k or 2,
top_k=retrieve_config.top_k or 4,
score_threshold=retrieve_config.score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
@ -958,7 +955,8 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(metadata_stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:
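Most hunks in this commit replace legacy db.session.query(...).first() / .all() calls with SQLAlchemy 2.0-style select() statements executed through scalar() and scalars(). A self-contained sketch of the equivalence, using a hypothetical Item model and an in-memory SQLite database rather than the project's models:

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Item(Base):  # hypothetical model
    __tablename__ = "items"
    id = Column(Integer, primary_key=True)
    name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Item(name="a"))
    session.commit()

    legacy = session.query(Item).where(Item.name == "a").first()   # 1.x style
    stmt = select(Item).where(Item.name == "a")                    # 2.0 style used in this commit
    modern = session.scalar(stmt)                                  # first row or None
    assert legacy is modern

    all_items = session.scalars(select(Item)).all()                # replaces .all()
    assert len(all_items) == 1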

View File

@ -1,3 +1,5 @@
from sqlalchemy import select
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
@ -54,17 +56,13 @@ class ToolLabelManager:
return controller.tool_labels
else:
raise ValueError("Unsupported tool type")
labels = (
db.session.query(ToolLabelBinding.label_name)
.where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
.all()
stmt = select(ToolLabelBinding.label_name).where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
labels = db.session.scalars(stmt).all()
return [label.label_name for label in labels]
return list(labels)
@classmethod
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
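Because the statement above selects a single column, scalars() yields the column values themselves rather than row objects, which is why the comprehension over label.label_name becomes a plain list(labels). A short sketch with a hypothetical Tag model:

from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Tag(Base):  # hypothetical stand-in for ToolLabelBinding
    __tablename__ = "tags"
    id = Column(Integer, primary_key=True)
    label_name = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Tag(label_name="search"), Tag(label_name="image")])
    session.commit()
    labels = session.scalars(select(Tag.label_name).order_by(Tag.id)).all()
    assert list(labels) == ["search", "image"]  # plain strings, no attribute access needed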

View File

@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa
from pydantic import TypeAdapter
from sqlalchemy import select
from sqlalchemy.orm import Session
from yarl import URL
@ -198,14 +199,11 @@ class ToolManager:
# get specific credentials
if is_valid_uuid(credential_id):
try:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
builtin_provider_stmt = select(BuiltinToolProvider).where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
builtin_provider = db.session.scalar(builtin_provider_stmt)
except Exception as e:
builtin_provider = None
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
@ -319,11 +317,10 @@ class ToolManager:
),
)
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
workflow_provider_stmt = select(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
)
workflow_provider = db.session.scalar(workflow_provider_stmt)
if workflow_provider is None:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
@ -333,16 +330,13 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@ -652,8 +646,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
):

View File

@ -3,6 +3,7 @@ from typing import Any
from flask import Flask, current_app
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.model_manager import ModelManager
@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
.all()
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
)
segments = db.session.scalars(document_segment_stmt).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
resource_number = 1
for segment in sorted_segments:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(document_stmt)
if dataset and document:
source = RetrievalSourceMetadata(
position=resource_number,
@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
hit_callbacks: list[DatasetIndexToolCallbackHandler],
):
with flask_app.app_context():
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
dataset = db.session.scalar(stmt)
if not dataset:
return []
@ -181,7 +175,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
)
if documents:
all_documents.extend(documents)
@ -192,7 +186,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k") or 2,
top_k=retrieval_model.get("top_k") or 4,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
name: str = "dataset"
description: str = "use this to retrieve a dataset. "
tenant_id: str
top_k: int = 2
top_k: int = 4
score_threshold: Optional[float] = None
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
return_resource: bool

View File

@ -1,6 +1,7 @@
from typing import Any, Optional, cast
from pydantic import BaseModel, Field
from sqlalchemy import select
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
)
def _run(self, query: str) -> str:
dataset = (
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
dataset = db.session.scalar(dataset_stmt)
if not dataset:
return ""
@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
.where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
document = db.session.scalar(dataset_document_stmt) # type: ignore
if dataset and document:
source = RetrievalSourceMetadata(
dataset_id=dataset.id,

View File

@ -3,7 +3,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional, cast
from typing import Optional
from uuid import UUID
import numpy as np
@ -159,8 +159,7 @@ class ToolFileMessageTransformer:
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message

View File

@ -129,17 +129,14 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")
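This hunk, like several later ones, drops a typing.cast(...) wrapper that became redundant once the callee's return type is annotated precisely. A tiny sketch of the idea, with a hypothetical stand-in for LLMResult:

from dataclasses import dataclass

@dataclass
class FakeLLMResult:  # hypothetical stand-in for LLMResult
    text: str

def invoke(prompt: str) -> FakeLLMResult:  # precise return annotation
    return FakeLLMResult(text=f"echo: {prompt}")

result = invoke("hi")  # type checkers infer FakeLLMResult; no cast() needed
assert result.text == "echo: hi"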

View File

@ -1,7 +1,9 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import select
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@ -133,7 +135,8 @@ class WorkflowTool(Tool):
.first()
)
else:
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = db.session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
@ -144,7 +147,8 @@ class WorkflowTool(Tool):
"""
get the app by app id
"""
app = db.session.query(App).where(App.id == app_id).first()
stmt = select(App).where(App.id == app_id)
app = db.session.scalar(stmt)
if not app:
raise ValueError("app not found")
@ -201,14 +205,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast
from typing import Annotated, TypeAlias
from uuid import uuid4
from pydantic import Discriminator, Field, Tag
@ -86,7 +86,7 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
return cast(str, encrypter.obfuscated_token(self.value))
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):

View File

@ -0,0 +1,339 @@
"""
QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution.
This engine uses a modular architecture with separated packages following
Domain-Driven Design principles for improved maintainability and testability.
"""
import contextvars
import logging
import queue
from collections.abc import Generator, Mapping
from typing import final
from flask import Flask, current_app
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
GraphEngineEvent,
GraphNodeEventBase,
GraphRunAbortedEvent,
GraphRunFailedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
)
from models.enums import UserFrom
from .command_processing import AbortCommandHandler, CommandProcessor
from .domain import ExecutionContext, GraphExecution
from .entities.commands import AbortCommand
from .error_handling import ErrorHandler
from .event_management import EventHandler, EventManager
from .graph_traversal import EdgeProcessor, SkipPropagator
from .layers.base import Layer
from .orchestration import Dispatcher, ExecutionCoordinator
from .protocols.command_channel import CommandChannel
from .response_coordinator import ResponseStreamCoordinator
from .state_management import UnifiedStateManager
from .worker_management import SimpleWorkerPool
logger = logging.getLogger(__name__)
@final
class GraphEngine:
"""
Queue-based graph execution engine.
Uses a modular architecture that delegates responsibilities to specialized
subsystems, following Domain-Driven Design and SOLID principles.
"""
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph: Graph,
graph_config: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
command_channel: CommandChannel,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
scale_down_idle_time: float | None = None,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# === Domain Models ===
# Execution context encapsulates workflow execution metadata
self._execution_context = ExecutionContext(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
max_execution_steps=max_execution_steps,
max_execution_time=max_execution_time,
)
# Graph execution tracks the overall execution state
self._graph_execution = GraphExecution(workflow_id=workflow_id)
# === Core Dependencies ===
# Graph structure and configuration
self._graph = graph
self._graph_config = graph_config
self._graph_runtime_state = graph_runtime_state
self._command_channel = command_channel
# === Worker Management Parameters ===
# Parameters for dynamic worker pool scaling
self._min_workers = min_workers
self._max_workers = max_workers
self._scale_up_threshold = scale_up_threshold
self._scale_down_idle_time = scale_down_idle_time
# === Execution Queues ===
# Queue for nodes ready to execute
self._ready_queue: queue.Queue[str] = queue.Queue()
# Queue for events generated during execution
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
# === State Management ===
# Unified state manager handles all node state transitions and queue operations
self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
# === Response Coordination ===
# Coordinates response streaming from response nodes
self._response_coordinator = ResponseStreamCoordinator(
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
)
# === Event Management ===
# Event manager handles both collection and emission of events
self._event_manager = EventManager()
# === Error Handling ===
# Centralized error handler for graph execution errors
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
# === Graph Traversal Components ===
# Propagates skip status through the graph when conditions aren't met
self._skip_propagator = SkipPropagator(
graph=self._graph,
state_manager=self._state_manager,
)
# Processes edges to determine next nodes after execution
# Also handles conditional branching and route selection
self._edge_processor = EdgeProcessor(
graph=self._graph,
state_manager=self._state_manager,
response_coordinator=self._response_coordinator,
skip_propagator=self._skip_propagator,
)
# === Event Handler Registry ===
# Central registry for handling all node execution events
self._event_handler_registry = EventHandler(
graph=self._graph,
graph_runtime_state=self._graph_runtime_state,
graph_execution=self._graph_execution,
response_coordinator=self._response_coordinator,
event_collector=self._event_manager,
edge_processor=self._edge_processor,
state_manager=self._state_manager,
error_handler=self._error_handler,
)
# === Command Processing ===
# Processes external commands (e.g., abort requests)
self._command_processor = CommandProcessor(
command_channel=self._command_channel,
graph_execution=self._graph_execution,
)
# Register abort command handler
abort_handler = AbortCommandHandler()
self._command_processor.register_handler(
AbortCommand,
abort_handler,
)
# === Worker Pool Setup ===
# Capture Flask app context for worker threads
flask_app: Flask | None = None
try:
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass
# Capture context variables for worker threads
context_vars = contextvars.copy_context()
# Create worker pool for parallel node execution
self._worker_pool = SimpleWorkerPool(
ready_queue=self._ready_queue,
event_queue=self._event_queue,
graph=self._graph,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,
scale_down_idle_time=self._scale_down_idle_time,
)
# === Orchestration ===
# Coordinates the overall execution lifecycle
self._execution_coordinator = ExecutionCoordinator(
graph_execution=self._graph_execution,
state_manager=self._state_manager,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
command_processor=self._command_processor,
worker_pool=self._worker_pool,
)
# Dispatches events and manages execution flow
self._dispatcher = Dispatcher(
event_queue=self._event_queue,
event_handler=self._event_handler_registry,
event_collector=self._event_manager,
execution_coordinator=self._execution_coordinator,
max_execution_time=self._execution_context.max_execution_time,
event_emitter=self._event_manager,
)
# === Extensibility ===
# Layers allow plugins to extend engine functionality
self._layers: list[Layer] = []
# === Validation ===
# Ensure all nodes share the same GraphRuntimeState instance
self._validate_graph_state_consistency()
def _validate_graph_state_consistency(self) -> None:
"""Validate that all nodes share the same GraphRuntimeState."""
expected_state_id = id(self._graph_runtime_state)
for node in self._graph.nodes.values():
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
def layer(self, layer: Layer) -> "GraphEngine":
"""Add a layer for extending functionality."""
self._layers.append(layer)
return self
def run(self) -> Generator[GraphEngineEvent, None, None]:
"""
Execute the graph using the modular architecture.
Returns:
Generator yielding GraphEngineEvent instances
"""
try:
# Initialize layers
self._initialize_layers()
# Start execution
self._graph_execution.start()
start_event = GraphRunStartedEvent()
yield start_event
# Start subsystems
self._start_execution()
# Yield events as they occur
yield from self._event_manager.emit_events()
# Handle completion
if self._graph_execution.aborted:
abort_reason = "Workflow execution aborted by user command"
if self._graph_execution.error:
abort_reason = str(self._graph_execution.error)
yield GraphRunAbortedEvent(
reason=abort_reason,
outputs=self._graph_runtime_state.outputs,
)
elif self._graph_execution.has_error:
if self._graph_execution.error:
raise self._graph_execution.error
else:
yield GraphRunSucceededEvent(
outputs=self._graph_runtime_state.outputs,
)
except Exception as e:
yield GraphRunFailedEvent(error=str(e))
raise
finally:
self._stop_execution()
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
for layer in self._layers:
try:
layer.initialize(self._graph_runtime_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
try:
layer.on_graph_start()
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
def _start_execution(self) -> None:
"""Start execution subsystems."""
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
# Register response nodes
for node in self._graph.nodes.values():
if node.execution_type == NodeExecutionType.RESPONSE:
self._response_coordinator.register(node.id)
# Enqueue root node
root_node = self._graph.root_node
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
# Start dispatcher
self._dispatcher.start()
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it
# Notify layers
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
except Exception as e:
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
# Public property accessors for attributes that need external access
@property
def graph_runtime_state(self) -> GraphRuntimeState:
"""Get the graph runtime state."""
return self._graph_runtime_state
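The engine's extension point is the layer list: _initialize_layers() calls initialize() and on_graph_start() before dispatching, and _stop_execution() calls on_graph_end(error) afterwards. A stand-alone sketch of that hook sequence; a real extension would subclass the Layer base imported from .layers.base, so the class below is only an illustrative stand-in that mirrors the three hooks:

import logging

logging.basicConfig(level=logging.INFO)

class TimingHooks:  # illustrative only, not a real Layer subclass
    def __init__(self) -> None:
        self._log = logging.getLogger("TimingHooks")

    def initialize(self, graph_runtime_state: object, command_channel: object) -> None:
        self._log.info("layer initialized")

    def on_graph_start(self) -> None:
        self._log.info("graph started")

    def on_graph_end(self, error: Exception | None) -> None:
        self._log.info("graph ended, error=%r", error)

# Same call order the engine uses:
hooks = TimingHooks()
hooks.initialize(graph_runtime_state=None, command_channel=None)
hooks.on_graph_start()
hooks.on_graph_end(None)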

View File

@ -157,7 +157,7 @@ class AgentNode(Node):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -401,8 +401,7 @@ class AgentNode(Node):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

View File

@ -301,12 +301,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e

View File

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import sessionmaker
@ -75,7 +75,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -358,15 +358,12 @@ class KnowledgeRetrievalNode(Node):
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
.where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
.first()
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
@ -505,7 +502,8 @@ class KnowledgeRetrievalNode(Node):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")

View File

@ -138,7 +138,7 @@ class ParameterExtractorNode(Node):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -103,7 +103,7 @@ class QuestionClassifierNode(Node):
return "1"
def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -61,7 +61,7 @@ class ToolNode(Node):
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data
# fetch tool icon
tool_info = {

View File

@ -0,0 +1,493 @@
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.enums import (
ErrorStrategy,
NodeType,
SystemVariableKey,
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from factories import file_factory
from models import ToolFile
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .entities import ToolNodeData
from .exc import (
ToolFileError,
ToolNodeError,
ToolParameterError,
)
if TYPE_CHECKING:
from core.workflow.entities import VariablePool
class ToolNode(Node):
"""
Tool Node
"""
node_type = NodeType.TOOL
_node_data: ToolNodeData
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = ToolNodeData.model_validate(data)
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
"""
Run the tool node
"""
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
node_data = self._node_data
# fetch tool icon
tool_info = {
"provider_type": node_data.provider_type.value,
"provider_id": node_data.provider_id,
"plugin_unique_identifier": node_data.plugin_unique_identifier,
}
# get tool runtime
try:
from core.tools.tool_manager import ToolManager
# This check has caused problems before.
# Logically, we shouldn't branch on the node_data.version field here,
# but the check is kept for backward compatibility with historical data.
variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version != "1":
variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
error_type=type(e).__name__,
)
)
return
# get parameters
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=self._node_data,
for_log=True,
)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
try:
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
tool_parameters=parameters,
user_id=self.user_id,
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
)
except ToolNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
error_type=type(e).__name__,
)
)
return
try:
# convert tool messages
yield from self._transform_message(
messages=message_stream,
tool_info=tool_info,
parameters_for_log=parameters_for_log,
user_id=self.user_id,
tenant_id=self.tenant_id,
node_id=self._node_id,
)
except ToolInvokeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
error_type=type(e).__name__,
)
)
except PluginInvokeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error="An error occurred in the plugin, "
f"please contact the author of {node_data.provider_name} for help, "
f"error type: {e.get_error_type()}, "
f"error details: {e.get_error_message()}",
error_type=type(e).__name__,
)
)
except PluginDaemonClientSideError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool, error: {e.description}",
error_type=type(e).__name__,
)
)
def _generate_parameters(
self,
*,
tool_parameters: Sequence[ToolParameter],
variable_pool: "VariablePool",
node_data: ToolNodeData,
for_log: bool = False,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
Args:
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
Returns:
dict[str, Any]: A dictionary containing the generated parameters.
"""
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
result: dict[str, Any] = {}
for parameter_name in node_data.tool_parameters:
parameter = tool_parameters_dictionary.get(parameter_name)
if not parameter:
result[parameter_name] = None
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
result[parameter_name] = parameter_value
return result
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
user_id: str,
tenant_id: str,
node_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into stream chunk events and collect text, files, JSON and variables
"""
# transform message and handle file storage
from core.plugin.impl.plugin import PluginInstaller
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
)
text = ""
files: list[File] = []
json: list[dict] = []
variables: dict[str, Any] = {}
for message in message_stream:
if message.type in {
ToolInvokeMessage.MessageType.IMAGE_LINK,
ToolInvokeMessage.MessageType.BINARY_LINK,
ToolInvokeMessage.MessageType.IMAGE,
}:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
url = message.message.text
if message.meta:
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
else:
transfer_method = FileTransferMethod.TOOL_FILE
tool_file_id = str(url).split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
"transfer_method": transfer_method,
"url": url,
}
file = file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
assert message.meta
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
with Session(db.engine) as session:
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
tool_file = session.scalar(stmt)
if tool_file is None:
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append(
file_factory.build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=message.message.text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
# JSON message handling for tool node
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=stream_text,
is_final=False,
)
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield StreamChunkEvent(
selector=[node_id, variable_name],
chunk=variable_value,
is_final=False,
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise ToolNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
if message.message.metadata:
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
)
icon = current_plugin.declaration.icon
except StopIteration:
pass
icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
user_id,
tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
# Add agent logs to outputs['json'] so the frontend can access the thinking process
json_output: list[dict[str, Any]] = []
# Normalize the collected JSON into a list[dict]; fall back to a single {"data": []} entry when nothing was collected
if json:
json_output.extend(json)
else:
json_output.append({"data": []})
# Send final chunk events for all streamed outputs
# Final chunk for text stream
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
# Final chunks for any streamed variables
for var_name in variables:
yield StreamChunkEvent(
selector=[self._node_id, var_name],
chunk="",
is_final=True,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
},
inputs=parameters_for_log,
)
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
:param graph_config: graph config
:param node_id: node id
:param node_data: node data
:return:
"""
# Create typed NodeData from dict
typed_node_data = ToolNodeData.model_validate(node_data)
result = {}
for parameter_name in typed_node_data.tool_parameters:
input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == "constant":
pass
result = {node_id + "." + key: value for key, value in result.items()}
return result
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
@property
def retry(self) -> bool:
return self._node_data.retry_config.retry_enabled
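_transform_message() above follows a simple streaming contract: each selector receives chunks with is_final=False, every streamed selector is then closed with an empty is_final=True chunk, and a StreamCompletedEvent carries the aggregated outputs. A sketch of a consumer of that contract, using a hypothetical Chunk dataclass in place of StreamChunkEvent:

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class Chunk:  # hypothetical stand-in for StreamChunkEvent
    selector: list[str]
    chunk: str
    is_final: bool

events = [
    Chunk(["node1", "text"], "Hello ", False),
    Chunk(["node1", "text"], "world", False),
    Chunk(["node1", "text"], "", True),  # closing chunk for the text stream
]

buffers: defaultdict[str, str] = defaultdict(str)
for event in events:
    key = ".".join(event.selector)
    buffers[key] += event.chunk
    if event.is_final:
        print(f"{key} completed: {buffers[key]!r}")

assert buffers["node1.text"] == "Hello world"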

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@ -318,7 +318,6 @@ class WorkflowEntry:
# init graph
graph = Graph.init(graph_config=graph_dict, node_factory=node_factory)
node_cls = cast(type[Node], node_cls)
# init workflow run state
node_config = {
"id": node_id,

View File

@ -0,0 +1,445 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
from core.workflow.nodes import NodeType
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
class WorkflowEntry:
def __init__(
self,
tenant_id: str,
app_id: str,
workflow_id: str,
graph_config: Mapping[str, Any],
graph: Graph,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int,
graph_runtime_state: GraphRuntimeState,
command_channel: Optional[CommandChannel] = None,
) -> None:
"""
Init workflow entry
:param tenant_id: tenant id
:param app_id: app id
:param workflow_id: workflow id
:param graph_config: workflow graph config
:param graph: workflow graph
:param user_id: user id
:param user_from: user from
:param invoke_from: invoke from
:param call_depth: call depth
:param graph_runtime_state: pre-created graph runtime state
:param command_channel: command channel for external control (optional, defaults to InMemoryChannel)
"""
# check call depth
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
if call_depth > workflow_call_max_depth:
raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.")
# Use provided command channel or default to InMemoryChannel
if command_channel is None:
command_channel = InMemoryChannel()
self.command_channel = command_channel
self.graph_engine = GraphEngine(
tenant_id=tenant_id,
app_id=app_id,
workflow_id=workflow_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
command_channel=command_channel,
)
# Add debug logging layer when in debug mode
if dify_config.DEBUG:
logger.info("Debug mode enabled - adding DebugLoggingLayer to GraphEngine")
debug_layer = DebugLoggingLayer(
level="DEBUG",
include_inputs=True,
include_outputs=True,
include_process_data=False, # Process data can be very verbose
logger_name=f"GraphEngine.Debug.{workflow_id[:8]}", # Use workflow ID prefix for unique logger
)
self.graph_engine.layer(debug_layer)
# Add execution limits layer
limits_layer = ExecutionLimitsLayer(
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
def run(self) -> Generator[GraphEngineEvent, None, None]:
graph_engine = self.graph_engine
try:
# run workflow
generator = graph_engine.run()
yield from generator
except GenerateTaskStoppedError:
pass
except Exception as e:
logger.exception("Unknown error while running workflow entry")
yield GraphRunFailedEvent(error=str(e))
return
@classmethod
def single_step_run(
cls,
*,
workflow: Workflow,
node_id: str,
user_id: str,
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Single step run workflow node
:param workflow: Workflow instance
:param node_id: node id
:param user_id: user id
:param user_inputs: user inputs
:param variable_pool: variable pool
:param variable_loader: loader used to fill variables missing from the pool
:return:
"""
node_config = workflow.get_node_config_by_id(node_id)
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
# init graph init params and runtime state
graph_init_params = GraphInitParams(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init node factory
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=workflow.graph_dict, node_factory=node_factory)
# init workflow run state
node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(node_config_data)
try:
# variable selector to variable mapping
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=workflow.graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}
# Load missing variables from the draft variables here and set them into
# the variable_pool.
load_into_variable_pool(
variable_loader=variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
try:
# run node
generator = node.run()
except Exception as e:
logger.exception(
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
workflow.id,
node.id,
node.node_type,
node.version(),
)
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
return node, generator
@staticmethod
def _create_single_node_graph(
node_id: str,
node_data: dict[str, Any],
node_width: int = 114,
node_height: int = 514,
) -> dict[str, Any]:
"""
Create a minimal graph structure for testing a single node in isolation.
:param node_id: ID of the target node
:param node_data: configuration data for the target node
:param node_width: width for UI layout (default: 114)
:param node_height: height for UI layout (default: 514)
:return: graph dictionary with start node and target node
"""
node_config = {
"id": node_id,
"width": node_width,
"height": node_height,
"type": "custom",
"data": node_data,
}
start_node_config = {
"id": "start",
"width": node_width,
"height": node_height,
"type": "custom",
"data": {
"type": NodeType.START.value,
"title": "Start",
"desc": "Start",
},
}
return {
"nodes": [start_node_config, node_config],
"edges": [
{
"source": "start",
"target": node_id,
"sourceHandle": "source",
"targetHandle": "target",
}
],
}
@classmethod
def run_free_node(
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
NOTE: only parameter_extractor/question_classifier are supported
:param node_data: node data
:param node_id: node id
:param tenant_id: tenant id
:param user_id: user id
:param user_inputs: user inputs
:return:
"""
# Create a minimal graph for single node execution
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = NodeType(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
# init variable pool
variable_pool = VariablePool(
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=[],
)
# init graph init params and runtime state
graph_init_params = GraphInitParams(
tenant_id=tenant_id,
app_id="",
workflow_id="",
graph_config=graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init node factory
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
# init graph
graph = Graph.init(graph_config=graph_dict, node_factory=node_factory)
# init workflow run state
node_config = {
"id": node_id,
"data": node_data,
}
node: Node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(node_data)
try:
# variable selector to variable mapping
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=tenant_id,
)
# run node
generator = node.run()
return node, generator
except Exception as e:
logger.exception(
"error while running node, node_id=%s, node_type=%s, node_version=%s",
node.id,
node.node_type,
node.version(),
)
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
@staticmethod
def _handle_special_values(value: Any) -> Any:
if value is None:
return value
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return value.to_dict()
return value
@classmethod
def mapping_user_inputs_to_variable_pool(
cls,
*,
variable_mapping: Mapping[str, Sequence[str]],
user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:
# NOTE(QuantumGhost): This logic should remain synchronized with
# the implementation of `load_into_variable_pool`, specifically the logic about
# variable existence checking.
# WARNING(QuantumGhost): The semantics of this method are not clearly defined,
# and multiple parts of the codebase depend on its current behavior.
# Modify with caution.
for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable
node_variable_list = node_variable.split(".")
if len(node_variable_list) < 1:
raise ValueError(f"Invalid node variable {node_variable}")
node_variable_key = ".".join(node_variable_list[1:])
if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
variable_selector
):
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
# environment variables already exist in the variable pool; they are not taken from user inputs
if variable_pool.get(variable_selector):
continue
# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]
variable_key_list = list(variable_key_list)
# get input value
input_value = user_inputs.get(node_variable)
if not input_value:
input_value = user_inputs.get(node_variable_key)
if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
if (
isinstance(input_value, list)
and all(isinstance(item, dict) for item in input_value)
and all("type" in item and "transfer_method" in item for item in input_value)
):
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
# append variable and value to variable pool
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
variable_pool.add([variable_node_id] + variable_key_list, input_value)
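# --- Illustrative sketch, not part of this change ----------------------------------
# How the mapping above is typically consumed when debugging a single node. Keys are
# "<node_id>.<variable_key>" and values are variable selectors; user inputs may use
# either the full key or the bare variable key. The ids below are made up.
#
#     variable_mapping = {"code.args1": ["start", "args1"]}
#     user_inputs = {"args1": 1}                  # {"code.args1": 1} would also match
#     WorkflowEntry.mapping_user_inputs_to_variable_pool(
#         variable_mapping=variable_mapping,
#         user_inputs=user_inputs,
#         variable_pool=variable_pool,
#         tenant_id=tenant_id,
#     )
#     variable_pool.get(["start", "args1"])       # -> 1, unless the selector targets
#                                                 #    ENVIRONMENT_VARIABLE_NODE_ID
# ------------------------------------------------------------------------------------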

View File

@@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Ch
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs import datetime_utils
from models.model import Message
from models.provider import Provider, ProviderType
@@ -19,6 +20,32 @@ from models.provider_ids import ModelProviderID
logger = logging.getLogger(__name__)
# Redis cache key prefix for provider last used timestamps
_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"
# Default TTL for cache entries (10 minutes)
_CACHE_TTL_SECONDS = 600
LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5
def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
"""Generate Redis cache key for provider last used timestamp."""
return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}"
@redis_fallback(default_return=None)
def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]:
"""Get last update timestamp from Redis cache."""
timestamp_str = redis_client.get(cache_key)
if timestamp_str:
return datetime.fromtimestamp(float(timestamp_str.decode("utf-8")))
return None
@redis_fallback()
def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None:
"""Set last update timestamp in Redis cache with TTL."""
redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp()))
class _ProviderUpdateFilters(BaseModel):
"""Filters for identifying Provider records to update."""
@@ -139,7 +166,7 @@ def handle(sender: Message, **kwargs):
provider_name,
)
except Exception as e:
except Exception:
# Log failure with timing and context
duration = time_module.perf_counter() - start_time
@@ -215,8 +242,23 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
# Prepare values dict for SQLAlchemy update
update_values = {}
# Unconditional (per-message) updates to `last_used` were removed for performance reasons.
# ref: https://github.com/langgenius/dify/issues/24526
# NOTE: For frequently used providers under high load, this implementation may experience
# race conditions or update contention despite the time-window optimization:
# 1. Multiple concurrent requests might check the same cache key simultaneously
# 2. Redis cache operations are not atomic with database updates
# 3. Heavy providers could still face database lock contention during peak usage
# The current implementation is acceptable for most scenarios, but future optimization
# considerations could include: batched updates, or async processing.
if values.last_used is not None:
cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
now = datetime_utils.naive_utc_now()
last_update = _get_last_update_timestamp(cache_key)
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
update_values["last_used"] = values.last_used
_set_last_update_timestamp(cache_key, now)
if values.quota_used is not None:
update_values["quota_used"] = values.quota_used
# Skip the current update operation if no updates are required.
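# --- Illustrative sketch, not part of this change ----------------------------------
# The read-then-write window check above is not atomic across workers (points 1 and 2
# in the NOTE). One possible mitigation, assumed here rather than taken from this diff,
# is to claim the window atomically with Redis SET NX/EX so that only the first writer
# in a given window touches `last_used`:
@redis_fallback(default_return=False)
def _try_claim_update_window(cache_key: str, now: datetime) -> bool:
    """Return True only for the request that creates the key in this window."""
    return bool(
        redis_client.set(
            cache_key,
            str(now.timestamp()),
            nx=True,  # only set if the key does not exist yet
            ex=LAST_USED_UPDATE_WINDOW_SECONDS,  # the window doubles as the TTL
        )
    )
# ------------------------------------------------------------------------------------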

View File

@@ -3,7 +3,7 @@ import os
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from typing import Any
import httpx
from sqlalchemy import select
@@ -258,7 +258,6 @@ def _get_remote_file_info(url: str):
mime_type = ""
resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))

View File

@@ -17,7 +17,7 @@ class EnvironmentVariableField(fields.Raw):
return {
"id": value.id,
"name": value.name,
"value": encrypter.obfuscated_token(value.value),
"value": encrypter.full_mask_token(),
"value_type": value.value_type.value,
"description": value.description,
}

View File

@@ -0,0 +1,32 @@
"""chore: add workflow app log run id index
Revision ID: b95962a3885c
Revises: 0e154742a5fa
Create Date: 2025-08-29 15:34:09.838623
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'b95962a3885c'
down_revision = '8d289573e1da'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op:
batch_op.create_index('workflow_app_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op:
batch_op.drop_index('workflow_app_log_workflow_run_id_idx')
# ### end Alembic commands ###

View File

@@ -319,7 +319,7 @@ class MCPToolProvider(Base):
@property
def decrypted_server_url(self) -> str:
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def masked_server_url(self) -> str:

View File

@@ -835,6 +835,7 @@ class WorkflowAppLog(Base):
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))

View File

@@ -146,7 +146,7 @@ class AccountService:
account.last_active_at = naive_utc_now()
db.session.commit()
return cast(Account, account)
return account
@staticmethod
def get_account_jwt_token(account: Account) -> str:
@@ -191,7 +191,7 @@ class AccountService:
db.session.commit()
return cast(Account, account)
return account
@staticmethod
def update_account_password(account, password, new_password):
@@ -1127,7 +1127,7 @@ class TenantService:
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
return tenant.custom_config_dict
@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:

View File

@@ -1,5 +1,5 @@
import uuid
from typing import cast
from typing import Optional
import pandas as pd
from flask_login import current_user
@@ -40,7 +40,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
annotation: Optional[MessageAnnotation] = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@@ -70,7 +70,7 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
return cast(MessageAnnotation, annotation)
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:

View File

@@ -1149,7 +1149,7 @@ class DocumentService:
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@@ -1612,7 +1612,7 @@ class DocumentService:
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
reranking_enable=False,
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
top_k=2,
top_k=4,
score_threshold_enabled=False,
)
# save dataset

View File

@@ -18,7 +18,7 @@ default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@@ -66,7 +66,7 @@ class HitTestingService:
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
dataset_id=dataset.id,
query=query,
top_k=retrieval_model.get("top_k", 2),
top_k=retrieval_model.get("top_k", 4),
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else 0.0,

View File

@@ -112,7 +112,9 @@ class MessageService:
base_query = base_query.where(Message.conversation_id == conversation.id)
# Check if include_ids is not None and not empty to avoid WHERE false condition
if include_ids is not None and len(include_ids) > 0:
if include_ids is not None:
if len(include_ids) == 0:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
base_query = base_query.where(Message.id.in_(include_ids))
if last_id:

View File

@@ -27,73 +27,73 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
except Exception as e:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
return
finally:
db.session.close()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()

View File

@@ -1,7 +1,6 @@
import time
import uuid
from os import getenv
from typing import cast
import pytest
@@ -11,7 +10,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@@ -242,8 +240,6 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node._node_data.outputs)
@@ -338,8 +334,6 @@ def test_execute_code_output_object_list():
]
}
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node._node_data.outputs)

View File

@@ -0,0 +1,390 @@
import time
import uuid
from os import getenv
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.node_factory import DifyNodeFactory
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
def init_code_node(code_config: dict):
graph_config = {
"edges": [
{
"id": "start-source-code-target",
"source": "start",
"target": "code",
},
],
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
}
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
# construct variable pool
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="aaa", files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
variable_pool.add(["code", "args1"], 1)
variable_pool.add(["code", "args2"], 2)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = CodeNode(
id=str(uuid.uuid4()),
config=code_config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code(setup_code_executor_mock):
code = """
def main(args1: int, args2: int) -> dict:
return {
"result": args1 + args2,
}
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
code_config = {
"id": "code",
"data": {
"outputs": {
"result": {
"type": "number",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
# execute node
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs is not None
assert result.outputs["result"] == 3
assert result.error == ""
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_code_output_validator(setup_code_executor_mock):
code = """
def main(args1: int, args2: int) -> dict:
return {
"result": args1 + args2,
}
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
code_config = {
"id": "code",
"data": {
"outputs": {
"result": {
"type": "string",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
# execute node
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert result.error == "Output variable `result` must be a string"
def test_execute_code_output_validator_depth():
code = """
def main(args1: int, args2: int) -> dict:
return {
"result": {
"result": args1 + args2,
}
}
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
code_config = {
"id": "code",
"data": {
"outputs": {
"string_validator": {
"type": "string",
},
"number_validator": {
"type": "number",
},
"number_array_validator": {
"type": "array[number]",
},
"string_array_validator": {
"type": "array[string]",
},
"object_validator": {
"type": "object",
"children": {
"result": {
"type": "number",
},
"depth": {
"type": "object",
"children": {
"depth": {
"type": "object",
"children": {
"depth": {
"type": "number",
}
},
}
},
},
},
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
# construct result
result = {
"number_validator": 1,
"string_validator": "1",
"number_array_validator": [1, 2, 3, 3.333],
"string_array_validator": ["1", "2", "3"],
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
# validate
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
"number_validator": "1",
"string_validator": 1,
"number_array_validator": ["1", "2", "3", "3.333"],
"string_array_validator": [1, 2, 3],
"object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}},
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
"number_validator": 1,
"string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1",
"number_array_validator": [1, 2, 3, 3.333],
"string_array_validator": ["1", "2", "3"],
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
"number_validator": 1,
"string_validator": "1",
"number_array_validator": [1, 2, 3, 3.333] * 2000,
"string_array_validator": ["1", "2", "3"],
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
code = """
def main(args1: int, args2: int) -> dict:
return {
"result": {
"result": args1 + args2,
}
}
"""
# trim first 4 spaces at the beginning of each line
code = "\n".join([line[4:] for line in code.split("\n")])
code_config = {
"id": "code",
"data": {
"outputs": {
"object_list": {
"type": "array[object]",
},
},
"title": "123",
"variables": [
{
"variable": "args1",
"value_selector": ["1", "args1"],
},
{"variable": "args2", "value_selector": ["1", "args2"]},
],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
# construct result
result = {
"object_list": [
{
"result": 1,
},
{
"result": 2,
},
{
"result": [1, 2, 3],
},
]
}
# validate
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
"object_list": [
{
"result": 1,
},
{
"result": 2,
},
{
"result": [1, 2, 3],
},
1,
]
}
# validate
with pytest.raises(ValueError):
node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation():
code = """
def main() -> dict:
return {
"result": -8.0E-5
}
"""
code = "\n".join([line[4:] for line in code.split("\n")])
code_config = {
"id": "code",
"data": {
"outputs": {
"result": {
"type": "number",
},
},
"title": "123",
"variables": [],
"answer": "123",
"code_language": "python3",
"code": code,
},
}
node = init_code_node(code_config)
# execute node
result = node._run()
assert isinstance(result, NodeRunResult)
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED

View File

@@ -0,0 +1,788 @@
from unittest.mock import Mock, patch
import pytest
from faker import Faker
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
class TestToolTransformService:
"""Integration tests for ToolTransformService using testcontainers."""
@pytest.fixture
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.tools.tools_transform_service.dify_config") as mock_dify_config,
):
# Setup default mock returns
mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
yield {
"dify_config": mock_dify_config,
}
def _create_test_tool_provider(
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
):
"""
Helper method to create a test tool provider for testing.
Args:
db_session_with_containers: Database session from testcontainers infrastructure
mock_external_service_dependencies: Mock dependencies
provider_type: Type of provider to create
Returns:
Tool provider instance
"""
fake = Faker()
if provider_type == "api":
provider = ApiToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
provider_type="api",
)
elif provider_type == "builtin":
provider = BuiltinToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon="🔧",
icon_dark="🔧",
tenant_id="test_tenant_id",
provider="test_provider",
credential_type="api_key",
credentials={"api_key": "test_key"},
)
elif provider_type == "workflow":
provider = WorkflowToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
workflow_id="test_workflow_id",
)
elif provider_type == "mcp":
provider = MCPToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id="test_tenant_id",
user_id="test_user_id",
server_url="https://mcp.example.com",
server_identifier="test_server",
tools='[{"name": "test_tool", "description": "Test tool"}]',
authed=True,
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
return provider
def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful plugin icon URL generation.
This test verifies:
- Proper URL construction for plugin icons
- Correct tenant_id and filename handling
- URL format compliance
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
filename = "test_icon.png"
# Act: Execute the method under test
result = ToolTransformService.get_plugin_icon_url(tenant_id, filename)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, str)
assert "console/api/workspaces/current/plugin/icon" in result
assert tenant_id in result
assert filename in result
assert result.startswith("https://console.example.com")
# Verify URL structure
expected_url = f"https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
assert result == expected_url
def test_get_plugin_icon_url_with_empty_console_url(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test plugin icon URL generation when CONSOLE_API_URL is empty.
This test verifies:
- Fallback to relative URL when CONSOLE_API_URL is None
- Proper URL construction with relative path
"""
# Arrange: Setup mock with empty console URL
mock_external_service_dependencies["dify_config"].CONSOLE_API_URL = None
fake = Faker()
tenant_id = fake.uuid4()
filename = "test_icon.png"
# Act: Execute the method under test
result = ToolTransformService.get_plugin_icon_url(tenant_id, filename)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, str)
assert result.startswith("/console/api/workspaces/current/plugin/icon")
assert tenant_id in result
assert filename in result
# Verify URL structure
expected_url = f"/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
assert result == expected_url
def test_get_tool_provider_icon_url_builtin_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for builtin providers.
This test verifies:
- Proper URL construction for builtin tool providers
- Correct provider type handling
- URL format compliance
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.BUILT_IN.value
provider_name = fake.company()
icon = "🔧"
# Act: Execute the method under test
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, str)
assert "console/api/workspaces/current/tool-provider/builtin" in result
# Note: provider_name may contain spaces that get URL encoded
assert provider_name.replace(" ", "%20") in result or provider_name in result
assert result.endswith("/icon")
assert result.startswith("https://console.example.com")
# Verify URL structure (accounting for URL encoding)
# The actual result will have URL-encoded spaces (%20), so we need to compare accordingly
expected_url = (
f"https://console.example.com/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon"
)
# Convert expected URL to match the actual URL encoding
expected_encoded = expected_url.replace(" ", "%20")
assert result == expected_encoded
def test_get_tool_provider_icon_url_api_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for API providers.
This test verifies:
- Proper icon handling for API tool providers
- JSON string parsing for icon data
- Fallback icon when parsing fails
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.API.value
provider_name = fake.company()
icon = '{"background": "#FF6B6B", "content": "🔧"}'
# Act: Execute the method under test
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, dict)
assert result["background"] == "#FF6B6B"
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_api_invalid_json(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tool provider icon URL generation for API providers with invalid JSON.
This test verifies:
- Proper fallback when JSON parsing fails
- Default icon structure when exception occurs
"""
# Arrange: Setup test data with invalid JSON
fake = Faker()
provider_type = ToolProviderType.API.value
provider_name = fake.company()
icon = '{"invalid": json}'
# Act: Execute the method under test
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, dict)
assert result["background"] == "#252525"
# Note: emoji characters may be represented as Unicode escape sequences
assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"
def test_get_tool_provider_icon_url_workflow_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for workflow providers.
This test verifies:
- Proper icon handling for workflow tool providers
- Direct icon return for workflow type
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.WORKFLOW.value
provider_name = fake.company()
icon = {"background": "#FF6B6B", "content": "🔧"}
# Act: Execute the method under test
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, dict)
assert result["background"] == "#FF6B6B"
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_mcp_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful tool provider icon URL generation for MCP providers.
This test verifies:
- Direct icon return for MCP type
- No URL transformation for MCP providers
"""
# Arrange: Setup test data
fake = Faker()
provider_type = ToolProviderType.MCP.value
provider_name = fake.company()
icon = {"background": "#FF6B6B", "content": "🔧"}
# Act: Execute the method under test
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
# Assert: Verify the expected outcomes
assert result is not None
assert isinstance(result, dict)
assert result["background"] == "#FF6B6B"
assert result["content"] == "🔧"
def test_get_tool_provider_icon_url_unknown_type(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test tool provider icon URL generation for unknown provider types.
This test verifies:
- Empty string return for unknown provider types
- Proper handling of unsupported types
"""
# Arrange: Setup test data with unknown type
fake = Faker()
provider_type = "unknown_type"
provider_name = fake.company()
icon = "🔧"
# Act: Execute the method under test
result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)
# Assert: Verify the expected outcomes
assert result == ""
def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful provider repacking with dictionary input.
This test verifies:
- Proper icon URL generation for dictionary providers
- Correct provider type handling
- Icon transformation for different provider types
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"}
# Act: Execute the method under test
ToolTransformService.repack_provider(tenant_id, provider)
# Assert: Verify the expected outcomes
assert "icon" in provider
assert isinstance(provider["icon"], str)
assert "console/api/workspaces/current/tool-provider/builtin" in provider["icon"]
# Note: provider name may contain spaces that get URL encoded
assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]
def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful provider repacking with ToolProviderApiEntity input.
This test verifies:
- Proper icon URL generation for entity providers
- Plugin icon handling when plugin_id is present
- Regular icon handling when plugin_id is not present
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
# Create provider entity with plugin_id
provider = ToolProviderApiEntity(
id=fake.uuid4(),
author=fake.name(),
name=fake.company(),
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
icon="test_icon.png",
icon_dark="test_icon_dark.png",
label=I18nObject(en_US=fake.company()),
type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
plugin_id="test_plugin_id",
tools=[],
labels=[],
)
# Act: Execute the method under test
ToolTransformService.repack_provider(tenant_id, provider)
# Assert: Verify the expected outcomes
assert provider.icon is not None
assert isinstance(provider.icon, str)
assert "console/api/workspaces/current/plugin/icon" in provider.icon
assert tenant_id in provider.icon
assert "test_icon.png" in provider.icon
# Verify dark icon handling
assert provider.icon_dark is not None
assert isinstance(provider.icon_dark, str)
assert "console/api/workspaces/current/plugin/icon" in provider.icon_dark
assert tenant_id in provider.icon_dark
assert "test_icon_dark.png" in provider.icon_dark
def test_repack_provider_entity_no_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful provider repacking with ToolProviderApiEntity input without plugin_id.
This test verifies:
- Proper icon URL generation for non-plugin providers
- Regular tool provider icon handling
- Dark icon handling when present
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
# Create provider entity without plugin_id
provider = ToolProviderApiEntity(
id=fake.uuid4(),
author=fake.name(),
name=fake.company(),
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark='{"background": "#252525", "content": "🔧"}',
label=I18nObject(en_US=fake.company()),
type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
plugin_id=None,
tools=[],
labels=[],
)
# Act: Execute the method under test
ToolTransformService.repack_provider(tenant_id, provider)
# Assert: Verify the expected outcomes
assert provider.icon is not None
assert isinstance(provider.icon, dict)
assert provider.icon["background"] == "#FF6B6B"
assert provider.icon["content"] == "🔧"
# Verify dark icon handling
assert provider.icon_dark is not None
assert isinstance(provider.icon_dark, dict)
assert provider.icon_dark["background"] == "#252525"
assert provider.icon_dark["content"] == "🔧"
def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test provider repacking with ToolProviderApiEntity input without dark icon.
This test verifies:
- Proper handling when icon_dark is None or empty
- No errors when dark icon is not present
"""
# Arrange: Setup test data
fake = Faker()
tenant_id = fake.uuid4()
# Create provider entity without dark icon
provider = ToolProviderApiEntity(
id=fake.uuid4(),
author=fake.name(),
name=fake.company(),
description=I18nObject(en_US=fake.text(max_nb_chars=100)),
icon='{"background": "#FF6B6B", "content": "🔧"}',
icon_dark=None,
label=I18nObject(en_US=fake.company()),
type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
plugin_id=None,
tools=[],
labels=[],
)
# Act: Execute the method under test
ToolTransformService.repack_provider(tenant_id, provider)
# Assert: Verify the expected outcomes
assert provider.icon is not None
assert isinstance(provider.icon, dict)
assert provider.icon["background"] == "#FF6B6B"
assert provider.icon["content"] == "🔧"
# Verify dark icon remains None
assert provider.icon_dark is None
def test_builtin_provider_to_user_provider_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful conversion of builtin provider to user provider.
This test verifies:
- Proper entity creation with all required fields
- Credentials schema handling
- Team authorization setup
- Plugin ID handling
"""
# Arrange: Setup test data
fake = Faker()
# Create mock provider controller
mock_controller = Mock()
mock_controller.entity.identity.name = fake.company()
mock_controller.entity.identity.author = fake.name()
mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
mock_controller.entity.identity.icon = "🔧"
mock_controller.entity.identity.icon_dark = "🔧"
mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
mock_controller.plugin_id = None
mock_controller.plugin_unique_identifier = None
mock_controller.tool_labels = ["label1", "label2"]
mock_controller.need_credentials = True
# Mock credentials schema
mock_credential = Mock()
mock_credential.to_basic_provider_config.return_value.name = "api_key"
mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
# Create mock database provider
mock_db_provider = Mock()
mock_db_provider.credential_type = "api-key"
mock_db_provider.tenant_id = fake.uuid4()
mock_db_provider.credentials = {"api_key": "encrypted_key"}
# Mock encryption
with patch("services.tools.tools_transform_service.create_provider_encrypter") as mock_encrypter:
mock_encrypter_instance = Mock()
mock_encrypter_instance.decrypt.return_value = {"api_key": "decrypted_key"}
mock_encrypter_instance.mask_tool_credentials.return_value = {"api_key": ""}
mock_encrypter.return_value = (mock_encrypter_instance, None)
# Act: Execute the method under test
result = ToolTransformService.builtin_provider_to_user_provider(
mock_controller, mock_db_provider, decrypt_credentials=True
)
# Assert: Verify the expected outcomes
assert result is not None
assert result.id == mock_controller.entity.identity.name
assert result.author == mock_controller.entity.identity.author
assert result.name == mock_controller.entity.identity.name
assert result.description == mock_controller.entity.identity.description
assert result.icon == mock_controller.entity.identity.icon
assert result.icon_dark == mock_controller.entity.identity.icon_dark
assert result.label == mock_controller.entity.identity.label
assert result.type == ToolProviderType.BUILT_IN
assert result.is_team_authorization is True
assert result.plugin_id is None
assert result.tools == []
assert result.labels == ["label1", "label2"]
assert result.masked_credentials == {"api_key": ""}
assert result.original_credentials == {"api_key": "decrypted_key"}
def test_builtin_provider_to_user_provider_plugin_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful conversion of builtin provider to user provider with plugin.
This test verifies:
- Plugin ID and unique identifier handling
- Proper entity creation for plugin providers
"""
# Arrange: Setup test data
fake = Faker()
# Create mock provider controller with plugin
mock_controller = Mock()
mock_controller.entity.identity.name = fake.company()
mock_controller.entity.identity.author = fake.name()
mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
mock_controller.entity.identity.icon = "🔧"
mock_controller.entity.identity.icon_dark = "🔧"
mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
mock_controller.plugin_id = "test_plugin_id"
mock_controller.plugin_unique_identifier = "test_unique_id"
mock_controller.tool_labels = ["label1"]
mock_controller.need_credentials = False
# Mock credentials schema
mock_credential = Mock()
mock_credential.to_basic_provider_config.return_value.name = "api_key"
mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
# Act: Execute the method under test
result = ToolTransformService.builtin_provider_to_user_provider(
mock_controller, None, decrypt_credentials=False
)
# Assert: Verify the expected outcomes
assert result is not None
# Note: The method checks isinstance(provider_controller, PluginToolProviderController)
# Since we're using a Mock, this check will fail, so plugin_id will remain None
# In a real test with actual PluginToolProviderController, this would work
assert result.is_team_authorization is True
assert result.allow_delete is False
def test_builtin_provider_to_user_provider_no_credentials(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test conversion of builtin provider to user provider without credentials.
This test verifies:
- Proper handling when no credentials are needed
- Team authorization setup for no-credentials providers
"""
# Arrange: Setup test data
fake = Faker()
# Create mock provider controller
mock_controller = Mock()
mock_controller.entity.identity.name = fake.company()
mock_controller.entity.identity.author = fake.name()
mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
mock_controller.entity.identity.icon = "🔧"
mock_controller.entity.identity.icon_dark = "🔧"
mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
mock_controller.plugin_id = None
mock_controller.plugin_unique_identifier = None
mock_controller.tool_labels = []
mock_controller.need_credentials = False
# Mock credentials schema
mock_credential = Mock()
mock_credential.to_basic_provider_config.return_value.name = "api_key"
mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]
# Act: Execute the method under test
result = ToolTransformService.builtin_provider_to_user_provider(
mock_controller, None, decrypt_credentials=False
)
# Assert: Verify the expected outcomes
assert result is not None
assert result.is_team_authorization is True
assert result.allow_delete is False
assert result.masked_credentials == {}
def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
"""
Test successful conversion of API provider to controller.
This test verifies:
- Proper controller creation from database provider
- Auth type handling for different credential types
- Backward compatibility for auth types
"""
# Arrange: Setup test data
fake = Faker()
# Create API tool provider with api_key_header auth
provider = ApiToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id=fake.uuid4(),
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "from_db")
# Additional assertions would depend on the actual controller implementation
def test_api_provider_to_controller_api_key_query(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test conversion of API provider to controller with api_key_query auth type.
This test verifies:
- Proper auth type handling for query parameter authentication
"""
# Arrange: Setup test data
fake = Faker()
# Create API tool provider with api_key_query auth
provider = ApiToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id=fake.uuid4(),
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "from_db")
def test_api_provider_to_controller_backward_compatibility(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test conversion of API provider to controller with backward compatibility auth types.
This test verifies:
- Proper handling of legacy auth type values
- Backward compatibility for api_key and api_key_header
"""
# Arrange: Setup test data
fake = Faker()
# Create API tool provider with legacy auth type
provider = ApiToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id=fake.uuid4(),
user_id=fake.uuid4(),
credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
schema="{}",
schema_type_str="openapi",
tools_str="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
# Act: Execute the method under test
result = ToolTransformService.api_provider_to_controller(provider)
# Assert: Verify the expected outcomes
assert result is not None
assert hasattr(result, "from_db")
def test_workflow_provider_to_controller_success(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test successful conversion of workflow provider to controller.
This test verifies:
- Proper controller creation from workflow provider
- Workflow-specific controller handling
"""
# Arrange: Setup test data
fake = Faker()
# Create workflow tool provider
provider = WorkflowToolProvider(
name=fake.company(),
description=fake.text(max_nb_chars=100),
icon='{"background": "#FF6B6B", "content": "🔧"}',
tenant_id=fake.uuid4(),
user_id=fake.uuid4(),
app_id=fake.uuid4(),
label="Test Workflow",
version="1.0.0",
parameter_configuration="[]",
)
from extensions.ext_database import db
db.session.add(provider)
db.session.commit()
# Mock the WorkflowToolProviderController.from_db method to avoid app dependency
with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:
mock_controller = Mock()
mock_from_db.return_value = mock_controller
# Act: Execute the method under test
result = ToolTransformService.workflow_provider_to_controller(provider)
# Assert: Verify the expected outcomes
assert result is not None
assert result == mock_controller
mock_from_db.assert_called_once_with(provider)

View File

@ -621,7 +621,7 @@ export default translation
&& !trimmed.startsWith('//'))
break
}
else {
else {
break
}

Some files were not shown because too many files have changed in this diff Show More