mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/queue-based-graph-engine
commit 0b0dc63f29
@@ -89,7 +89,9 @@ jobs:
       - name: Web style check
         if: steps.changed-files.outputs.any_changed == 'true'
         working-directory: ./web
-        run: pnpm run lint
+        run: |
+          pnpm run lint
+          pnpm run eslint

   docker-compose-template:
     name: Docker Compose Template
@@ -44,22 +44,19 @@ def oauth_server_access_token_required(view):
         if not oauth_provider_app or not isinstance(oauth_provider_app, OAuthProviderApp):
             raise BadRequest("Invalid oauth_provider_app")

-        if not request.headers.get("Authorization"):
-            raise BadRequest("Authorization is required")
+        authorization_header = request.headers.get("Authorization")
+        if not authorization_header:
+            raise BadRequest("Authorization header is required")

-        parts = authorization_header.split(" ")
+        parts = authorization_header.strip().split(" ")
         if len(parts) != 2:
             raise BadRequest("Invalid Authorization header format")

-        token_type = parts[0]
-        if token_type != "Bearer":
+        token_type = parts[0].strip()
+        if token_type.lower() != "bearer":
             raise BadRequest("token_type is invalid")

-        access_token = parts[1]
+        access_token = parts[1].strip()
         if not access_token:
             raise BadRequest("access_token is required")

@@ -125,7 +122,10 @@ class OAuthServerUserTokenApi(Resource):
         parser.add_argument("refresh_token", type=str, required=False, location="json")
         parsed_args = parser.parse_args()

-        grant_type = OAuthGrantType(parsed_args["grant_type"])
+        try:
+            grant_type = OAuthGrantType(parsed_args["grant_type"])
+        except ValueError:
+            raise BadRequest("invalid grant_type")

         if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
             if not parsed_args["code"]:

@@ -163,8 +163,6 @@ class OAuthServerUserTokenApi(Resource):
                     "refresh_token": refresh_token,
                 }
             )
-        else:
-            raise BadRequest("invalid grant_type")


 class OAuthServerUserAccountApi(Resource):
@@ -354,9 +354,6 @@ class DatasetInitApi(Resource):
         parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
         args = parser.parse_args()

-        # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
-        if not current_user.is_dataset_editor:
-            raise Forbidden()
         knowledge_config = KnowledgeConfig(**args)
         if knowledge_config.indexing_technique == "high_quality":
             if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
@@ -1,8 +1,12 @@
 from base64 import b64encode
+from collections.abc import Callable
 from functools import wraps
 from hashlib import sha1
 from hmac import new as hmac_new
+from typing import ParamSpec, TypeVar

+P = ParamSpec("P")
+R = TypeVar("R")
 from flask import abort, request

 from configs import dify_config

@@ -10,9 +14,9 @@ from extensions.ext_database import db
 from models.model import EndUser


-def billing_inner_api_only(view):
+def billing_inner_api_only(view: Callable[P, R]):
     @wraps(view)
-    def decorated(*args, **kwargs):
+    def decorated(*args: P.args, **kwargs: P.kwargs):
         if not dify_config.INNER_API:
             abort(404)

@@ -26,9 +30,9 @@ def billing_inner_api_only(view):
     return decorated


-def enterprise_inner_api_only(view):
+def enterprise_inner_api_only(view: Callable[P, R]):
     @wraps(view)
-    def decorated(*args, **kwargs):
+    def decorated(*args: P.args, **kwargs: P.kwargs):
         if not dify_config.INNER_API:
             abort(404)

@@ -78,9 +82,9 @@ def enterprise_inner_api_user_auth(view):
     return decorated


-def plugin_inner_api_only(view):
+def plugin_inner_api_only(view: Callable[P, R]):
     @wraps(view)
-    def decorated(*args, **kwargs):
+    def decorated(*args: P.args, **kwargs: P.kwargs):
         if not dify_config.PLUGIN_DAEMON_KEY:
             abort(404)
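
Note: the decorator hunks above all follow the standard ParamSpec/TypeVar recipe for typing wrappers so the wrapped view keeps its signature for type checkers. A minimal, self-contained sketch of that recipe (the decorator name and guard are illustrative, not taken from the diff):

    from collections.abc import Callable
    from functools import wraps
    from typing import ParamSpec, TypeVar

    P = ParamSpec("P")
    R = TypeVar("R")


    def require_feature_flag(view: Callable[P, R]) -> Callable[P, R]:
        """Illustrative decorator: P.args/P.kwargs preserve the view's parameters."""

        @wraps(view)
        def decorated(*args: P.args, **kwargs: P.kwargs) -> R:
            # a real guard (e.g. aborting when a config flag is off) would go here
            return view(*args, **kwargs)

        return decorated
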
@@ -1,7 +1,7 @@
 import time
 from collections.abc import Callable
 from datetime import timedelta
-from enum import Enum
+from enum import StrEnum, auto
 from functools import wraps
 from typing import Optional

@@ -23,14 +23,14 @@ from models.model import ApiToken, App, EndUser
 from services.feature_service import FeatureService


-class WhereisUserArg(Enum):
+class WhereisUserArg(StrEnum):
     """
     Enum for whereis_user_arg.
     """

-    QUERY = "query"
-    JSON = "json"
-    FORM = "form"
+    QUERY = auto()
+    JSON = auto()
+    FORM = auto()


 class FetchUserArg(BaseModel):
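
Note: on a StrEnum, auto() produces the lower-cased member name, so the members keep their previous string values "query", "json" and "form" and existing comparisons against plain strings keep working (Python 3.11+). A quick illustrative check:

    from enum import StrEnum, auto


    class WhereisUserArg(StrEnum):
        QUERY = auto()
        JSON = auto()
        FORM = auto()


    # auto() yields the lower-cased member name for StrEnum members
    assert WhereisUserArg.QUERY == "query"
    assert WhereisUserArg("form") is WhereisUserArg.FORM
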
@@ -334,7 +334,8 @@ class BaseAgentRunner(AppRunner):
         """
         Save agent thought
         """
-        agent_thought = db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id).first()
+        stmt = select(MessageAgentThought).where(MessageAgentThought.id == agent_thought_id)
+        agent_thought = db.session.scalar(stmt)
         if not agent_thought:
             raise ValueError("agent thought not found")

@@ -492,7 +493,8 @@ class BaseAgentRunner(AppRunner):
         return result

     def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
-        files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
+        stmt = select(MessageFile).where(MessageFile.message_id == message.id)
+        files = db.session.scalars(stmt).all()
         if not files:
             return UserPromptMessage(content=message.query)
         if message.app_model_config:
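
Note: most of the remaining hunks repeat a single mechanical migration, replacing legacy `db.session.query(...).first()` / `.all()` calls with SQLAlchemy 2.0-style select() statements executed through scalar() or scalars(). A self-contained sketch of the equivalence against a throwaway SQLite database (the model is illustrative):

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class MessageFile(Base):
        __tablename__ = "message_files"
        id: Mapped[int] = mapped_column(primary_key=True)
        message_id: Mapped[str] = mapped_column(String(36))


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(MessageFile(message_id="m-1"))
        session.commit()

        # legacy Query API, as used before this change
        legacy = session.query(MessageFile).where(MessageFile.message_id == "m-1").first()

        # 2.0 style: build a select() statement, then execute it
        stmt = select(MessageFile).where(MessageFile.message_id == "m-1")
        first = session.scalar(stmt)            # first ORM object or None
        all_rows = session.scalars(stmt).all()  # list of ORM objects

        assert legacy is not None and first is not None
        assert legacy.id == first.id and len(all_rows) == 1
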
@@ -450,6 +450,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):

         worker_thread.start()

+        # release database connection, because the following new thread operations may take a long time
+        db.session.refresh(workflow)
+        db.session.refresh(message)
+        db.session.refresh(user)
+        db.session.close()
+
         # return response or stream generator
         response = self._handle_advanced_chat_response(
             application_generate_entity=application_generate_entity,
@@ -73,7 +73,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
         app_config = self.application_generate_entity.app_config
         app_config = cast(AdvancedChatAppConfig, app_config)

-        app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            app_record = session.scalar(select(App).where(App.id == app_config.app_id))
+
         if not app_record:
             raise ValueError("App not found")

@@ -147,7 +149,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
             environment_variables=self._workflow.environment_variables,
             # Based on the definition of `VariableUnion`,
             # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
-            conversation_variables=cast(list[VariableUnion], conversation_variables),
+            conversation_variables=conversation_variables,
         )

         # init graph
@@ -68,7 +68,6 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
 from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
 from core.workflow.system_variable import SystemVariable
 from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
-from events.message_event import message_was_created
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models import Conversation, EndUser, Message, MessageFile

@@ -886,10 +885,6 @@ class AdvancedChatAppGenerateTaskPipeline:
             self._task_state.metadata.usage = usage
         else:
             self._task_state.metadata.usage = LLMUsage.empty_usage()
-        message_was_created.send(
-            message,
-            application_generate_entity=self._application_generate_entity,
-        )

     def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
         """
@@ -1,6 +1,8 @@
 import logging
 from typing import cast

+from sqlalchemy import select
+
 from core.agent.cot_chat_agent_runner import CotChatAgentRunner
 from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
 from core.agent.entities import AgentEntity

@@ -44,8 +46,8 @@ class AgentChatAppRunner(AppRunner):
         """
         app_config = application_generate_entity.app_config
         app_config = cast(AgentChatAppConfig, app_config)

-        app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+        app_stmt = select(App).where(App.id == app_config.app_id)
+        app_record = db.session.scalar(app_stmt)
         if not app_record:
             raise ValueError("App not found")

@@ -182,11 +184,12 @@ class AgentChatAppRunner(AppRunner):

         if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
             agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING

-        conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
+        conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
+        conversation_result = db.session.scalar(conversation_stmt)
         if conversation_result is None:
             raise ValueError("Conversation not found")
-        message_result = db.session.query(Message).where(Message.id == message.id).first()
+        msg_stmt = select(Message).where(Message.id == message.id)
+        message_result = db.session.scalar(msg_stmt)
         if message_result is None:
             raise ValueError("Message not found")
         db.session.close()
@@ -1,7 +1,7 @@
 import queue
 import time
 from abc import abstractmethod
-from enum import Enum
+from enum import IntEnum, auto
 from typing import Any, Optional

 from sqlalchemy.orm import DeclarativeMeta

@@ -19,9 +19,9 @@ from core.app.entities.queue_entities import (
 from extensions.ext_redis import redis_client


-class PublishFrom(Enum):
-    APPLICATION_MANAGER = 1
-    TASK_PIPELINE = 2
+class PublishFrom(IntEnum):
+    APPLICATION_MANAGER = auto()
+    TASK_PIPELINE = auto()


 class AppQueueManager:
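
Note: for an IntEnum, auto() numbers the members from 1, so APPLICATION_MANAGER and TASK_PIPELINE keep their previous values 1 and 2 and anything persisted or compared numerically is unaffected. A quick illustrative check:

    from enum import IntEnum, auto


    class PublishFrom(IntEnum):
        APPLICATION_MANAGER = auto()
        TASK_PIPELINE = auto()


    # auto() starts counting at 1, so the numeric values are unchanged
    assert PublishFrom.APPLICATION_MANAGER == 1
    assert PublishFrom.TASK_PIPELINE == 2
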
@@ -1,6 +1,8 @@
 import logging
 from typing import cast

+from sqlalchemy import select
+
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
 from core.app.apps.chat.app_config_manager import ChatAppConfig

@@ -42,8 +44,8 @@ class ChatAppRunner(AppRunner):
         """
         app_config = application_generate_entity.app_config
         app_config = cast(ChatAppConfig, app_config)

-        app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+        stmt = select(App).where(App.id == app_config.app_id)
+        app_record = db.session.scalar(stmt)
         if not app_record:
             raise ValueError("App not found")
@@ -6,6 +6,7 @@ from typing import Any, Literal, Union, overload

 from flask import Flask, copy_current_request_context, current_app
 from pydantic import ValidationError
+from sqlalchemy import select

 from configs import dify_config
 from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter

@@ -248,17 +249,14 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
         :param invoke_from: invoke from source
         :param stream: is stream
         """
-        message = (
-            db.session.query(Message)
-            .where(
-                Message.id == message_id,
-                Message.app_id == app_model.id,
-                Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
-                Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
-                Message.from_account_id == (user.id if isinstance(user, Account) else None),
-            )
-            .first()
+        stmt = select(Message).where(
+            Message.id == message_id,
+            Message.app_id == app_model.id,
+            Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
+            Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
+            Message.from_account_id == (user.id if isinstance(user, Account) else None),
         )
+        message = db.session.scalar(stmt)

         if not message:
             raise MessageNotExistsError()
@@ -1,6 +1,8 @@
 import logging
 from typing import cast

+from sqlalchemy import select
+
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.base_app_runner import AppRunner
 from core.app.apps.completion.app_config_manager import CompletionAppConfig

@@ -35,8 +37,8 @@ class CompletionAppRunner(AppRunner):
         """
         app_config = application_generate_entity.app_config
         app_config = cast(CompletionAppConfig, app_config)

-        app_record = db.session.query(App).where(App.id == app_config.app_id).first()
+        stmt = select(App).where(App.id == app_config.app_id)
+        app_record = db.session.scalar(stmt)
         if not app_record:
             raise ValueError("App not found")
@@ -3,6 +3,9 @@ import logging
 from collections.abc import Generator
 from typing import Optional, Union, cast

+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
 from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
 from core.app.apps.base_app_generator import BaseAppGenerator
 from core.app.apps.base_app_queue_manager import AppQueueManager

@@ -83,11 +86,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):

     def _get_app_model_config(self, app_model: App, conversation: Optional[Conversation] = None) -> AppModelConfig:
         if conversation:
-            app_model_config = (
-                db.session.query(AppModelConfig)
-                .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
-                .first()
+            stmt = select(AppModelConfig).where(
+                AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
             )
+            app_model_config = db.session.scalar(stmt)

             if not app_model_config:
                 raise AppModelConfigBrokenError()

@@ -253,7 +255,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         :param conversation_id: conversation id
         :return: conversation
         """
-        conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            conversation = session.scalar(select(Conversation).where(Conversation.id == conversation_id))

         if not conversation:
             raise ConversationNotExistsError("Conversation not exists")

@@ -266,7 +269,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
         :param message_id: message id
         :return: message
         """
-        message = db.session.query(Message).where(Message.id == message_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            message = session.scalar(select(Message).where(Message.id == message_id))

         if message is None:
             raise MessageNotExistsError("Message not exists")
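
Note: several lookups above now run inside a short-lived `Session(db.engine, expire_on_commit=False)` block. With expire_on_commit=False the session does not expire attributes when it commits, so an object fetched inside the block stays readable after the session is closed and can be returned to the caller as a detached instance. A small self-contained sketch (the model is illustrative):

    from sqlalchemy import create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Conversation(Base):
        __tablename__ = "conversations"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as seed:
        seed.add(Conversation(name="demo"))
        seed.commit()

    # expire_on_commit=False keeps loaded attributes from being expired on
    # commit, so the detached object is still usable once the block exits
    with Session(engine, expire_on_commit=False) as session:
        conversation = session.scalar(select(Conversation).where(Conversation.id == 1))

    assert conversation is not None and conversation.name == "demo"
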
@@ -1,6 +1,8 @@
 import logging
 from typing import Optional

+from sqlalchemy import select
+
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.rag.datasource.vdb.vector_factory import Vector
 from extensions.ext_database import db

@@ -25,9 +27,8 @@ class AnnotationReplyFeature:
         :param invoke_from: invoke from
         :return:
         """
-        annotation_setting = (
-            db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
-        )
+        stmt = select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id)
+        annotation_setting = db.session.scalar(stmt)

         if not annotation_setting:
             return None
@@ -472,9 +472,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
         :param event: agent thought event
         :return:
         """
-        agent_thought: Optional[MessageAgentThought] = (
-            db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
-        )
+        with Session(db.engine, expire_on_commit=False) as session:
+            agent_thought: Optional[MessageAgentThought] = (
+                session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
+            )

         if agent_thought:
             return AgentThoughtStreamResponse(
@@ -3,6 +3,8 @@ from threading import Thread
 from typing import Optional, Union

 from flask import Flask, current_app
+from sqlalchemy import select
+from sqlalchemy.orm import Session

 from configs import dify_config
 from core.app.entities.app_invoke_entities import (

@@ -84,7 +86,8 @@ class MessageCycleManager:
     def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
         with flask_app.app_context():
             # get conversation and message
-            conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+            stmt = select(Conversation).where(Conversation.id == conversation_id)
+            conversation = db.session.scalar(stmt)

             if not conversation:
                 return

@@ -143,7 +146,8 @@ class MessageCycleManager:
         :param event: event
         :return:
         """
-        message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))

         if message_file and message_file.url is not None:
             # get tool file id

@@ -183,7 +187,8 @@ class MessageCycleManager:
         :param message_id: message id
         :return:
         """
-        message_file = db.session.query(MessageFile).where(MessageFile.id == message_id).first()
+        with Session(db.engine, expire_on_commit=False) as session:
+            message_file = session.scalar(select(MessageFile).where(MessageFile.id == message_id))
         event_type = StreamEvent.MESSAGE_FILE if message_file else StreamEvent.MESSAGE

         return MessageStreamResponse(
@@ -1,6 +1,8 @@
 import logging
 from collections.abc import Sequence

+from sqlalchemy import select
+
 from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueRetrieverResourcesEvent

@@ -49,7 +51,8 @@ class DatasetIndexToolCallbackHandler:
         for document in documents:
             if document.metadata is not None:
                 document_id = document.metadata["document_id"]
-                dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
+                dataset_document_stmt = select(DatasetDocument).where(DatasetDocument.id == document_id)
+                dataset_document = db.session.scalar(dataset_document_stmt)
                 if not dataset_document:
                     _logger.warning(
                         "Expected DatasetDocument record to exist, but none was found, document_id=%s",

@@ -57,15 +60,12 @@ class DatasetIndexToolCallbackHandler:
                     )
                     continue
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
-                    child_chunk = (
-                        db.session.query(ChildChunk)
-                        .where(
-                            ChildChunk.index_node_id == document.metadata["doc_id"],
-                            ChildChunk.dataset_id == dataset_document.dataset_id,
-                            ChildChunk.document_id == dataset_document.id,
-                        )
-                        .first()
+                    child_chunk_stmt = select(ChildChunk).where(
+                        ChildChunk.index_node_id == document.metadata["doc_id"],
+                        ChildChunk.dataset_id == dataset_document.dataset_id,
+                        ChildChunk.document_id == dataset_document.id,
                     )
+                    child_chunk = db.session.scalar(child_chunk_stmt)
                     if child_chunk:
                         segment = (
                             db.session.query(DocumentSegment)
@@ -1,5 +1,7 @@
 from typing import Optional

+from sqlalchemy import select
+
 from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
 from core.external_data_tool.base import ExternalDataTool
 from core.helper import encrypter

@@ -28,13 +30,11 @@ class ApiExternalDataTool(ExternalDataTool):
         api_based_extension_id = config.get("api_based_extension_id")
         if not api_based_extension_id:
             raise ValueError("api_based_extension_id is required")

         # get api_based_extension
-        api_based_extension = (
-            db.session.query(APIBasedExtension)
-            .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
-            .first()
+        stmt = select(APIBasedExtension).where(
+            APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
         )
+        api_based_extension = db.session.scalar(stmt)

         if not api_based_extension:
             raise ValueError("api_based_extension_id is invalid")

@@ -52,13 +52,11 @@ class ApiExternalDataTool(ExternalDataTool):
             raise ValueError(f"config is required, config: {self.config}")
         api_based_extension_id = self.config.get("api_based_extension_id")
         assert api_based_extension_id is not None, "api_based_extension_id is required"

         # get api_based_extension
-        api_based_extension = (
-            db.session.query(APIBasedExtension)
-            .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
-            .first()
+        stmt = select(APIBasedExtension).where(
+            APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id
         )
+        api_based_extension = db.session.scalar(stmt)

         if not api_based_extension:
             raise ValueError(
@@ -3,7 +3,7 @@ import base64
 from libs import rsa


-def obfuscated_token(token: str):
+def obfuscated_token(token: str) -> str:
     if not token:
         return token
     if len(token) <= 8:

@@ -11,6 +11,10 @@ def obfuscated_token(token: str):
     return token[:6] + "*" * 12 + token[-2:]


+def full_mask_token(token_length=20):
+    return "*" * token_length
+
+
 def encrypt_token(tenant_id: str, token: str):
     from extensions.ext_database import db
     from models.account import Tenant
@@ -1,6 +1,6 @@
 from collections.abc import Sequence

-import requests
+import httpx
 from yarl import URL

 from configs import dify_config

@@ -23,7 +23,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
         return []

     url = str(marketplace_api_url / "api/v1/plugins/batch")
-    response = requests.post(url, json={"plugin_ids": plugin_ids})
+    response = httpx.post(url, json={"plugin_ids": plugin_ids})
     response.raise_for_status()

     return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]

@@ -36,7 +36,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
         return []

     url = str(marketplace_api_url / "api/v1/plugins/batch")
-    response = requests.post(url, json={"plugin_ids": plugin_ids})
+    response = httpx.post(url, json={"plugin_ids": plugin_ids})
     response.raise_for_status()
     result: list[MarketplacePluginDeclaration] = []
     for plugin in response.json()["data"]["plugins"]:

@@ -50,5 +50,5 @@

 def record_install_plugin_event(plugin_unique_identifier: str):
     url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
-    response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
+    response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
     response.raise_for_status()
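
Note: for these simple synchronous POSTs, httpx is close to a drop-in replacement for requests: the json= payload and raise_for_status() behave the same way, though httpx applies a default timeout (5 seconds) where requests would wait indefinitely. A hedged sketch with a placeholder endpoint:

    import httpx

    # placeholder URL for illustration only
    url = "https://marketplace.example.com/api/v1/plugins/batch"

    response = httpx.post(url, json={"plugin_ids": ["plugin-a", "plugin-b"]})
    response.raise_for_status()  # raises httpx.HTTPStatusError on 4xx/5xx responses
    plugins = response.json()["data"]["plugins"]
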
@@ -8,6 +8,7 @@ import uuid
 from typing import Any, Optional, cast

 from flask import current_app
+from sqlalchemy import select
 from sqlalchemy.orm.exc import ObjectDeletedError

 from configs import dify_config

@@ -56,13 +57,11 @@ class IndexingRunner:

         if not dataset:
             raise ValueError("no dataset found")

         # get the process rule
-        processing_rule = (
-            db.session.query(DatasetProcessRule)
-            .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
-            .first()
+        stmt = select(DatasetProcessRule).where(
+            DatasetProcessRule.id == dataset_document.dataset_process_rule_id
         )
+        processing_rule = db.session.scalar(stmt)
         if not processing_rule:
             raise ValueError("no process rule found")
         index_type = dataset_document.doc_form

@@ -123,11 +122,8 @@ class IndexingRunner:
             db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
             db.session.commit()
             # get the process rule
-            processing_rule = (
-                db.session.query(DatasetProcessRule)
-                .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
-                .first()
-            )
+            stmt = select(DatasetProcessRule).where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+            processing_rule = db.session.scalar(stmt)
             if not processing_rule:
                 raise ValueError("no process rule found")
@@ -208,7 +204,6 @@ class IndexingRunner:
                     child_documents.append(child_document)
                 document.children = child_documents
             documents.append(document)

         # build index
         index_type = dataset_document.doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()

@@ -310,7 +305,8 @@ class IndexingRunner:
         # delete image files and related db records
         image_upload_file_ids = get_image_upload_file_ids(document.page_content)
         for upload_file_id in image_upload_file_ids:
-            image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
+            stmt = select(UploadFile).where(UploadFile.id == upload_file_id)
+            image_file = db.session.scalar(stmt)
             if image_file is None:
                 continue
             try:

@@ -339,10 +335,8 @@ class IndexingRunner:
         if dataset_document.data_source_type == "upload_file":
             if not data_source_info or "upload_file_id" not in data_source_info:
                 raise ValueError("no upload file found")

-            file_detail = (
-                db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
-            )
+            stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"])
+            file_detail = db.session.scalars(stmt).one_or_none()

             if file_detail:
                 extract_setting = ExtractSetting(
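
Note: two different executors are used on purpose in these rewrites: `db.session.scalar(stmt)` returns the first matching object or None (mirroring `.first()`), while `db.session.scalars(stmt).one_or_none()` additionally raises if more than one row matches (mirroring `.one_or_none()`). A small self-contained sketch (the model is illustrative):

    from sqlalchemy import create_engine, select
    from sqlalchemy.exc import MultipleResultsFound
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class UploadFile(Base):
        __tablename__ = "upload_files"
        id: Mapped[int] = mapped_column(primary_key=True)
        name: Mapped[str]


    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([UploadFile(name="a"), UploadFile(name="a")])
        session.commit()

        stmt = select(UploadFile).where(UploadFile.name == "a")

        # scalar(): first matching object or None; extra rows are ignored
        assert session.scalar(stmt) is not None

        # scalars().one_or_none(): raises when more than one row matches
        try:
            session.scalars(stmt).one_or_none()
        except MultipleResultsFound:
            pass
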
@@ -110,9 +110,9 @@ class TokenBufferMemory:
         else:
             message_limit = 500

-        stmt = stmt.limit(message_limit)
+        msg_limit_stmt = stmt.limit(message_limit)

-        messages = db.session.scalars(stmt).all()
+        messages = db.session.scalars(msg_limit_stmt).all()

         # instead of all messages from the conversation, we only need to extract messages
         # that belong to the thread of last message
@@ -158,8 +158,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, LargeLanguageModel):
             raise Exception("Model type instance is not LargeLanguageModel")
-
-        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
         return cast(
             Union[LLMResult, Generator],
             self._round_robin_invoke(

@@ -188,8 +186,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, LargeLanguageModel):
             raise Exception("Model type instance is not LargeLanguageModel")
-
-        self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
         return cast(
             int,
             self._round_robin_invoke(

@@ -214,8 +210,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-
-        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
         return cast(
             TextEmbeddingResult,
             self._round_robin_invoke(

@@ -237,8 +231,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TextEmbeddingModel):
             raise Exception("Model type instance is not TextEmbeddingModel")
-
-        self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
         return cast(
             list[int],
             self._round_robin_invoke(

@@ -269,8 +261,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, RerankModel):
             raise Exception("Model type instance is not RerankModel")
-
-        self.model_type_instance = cast(RerankModel, self.model_type_instance)
         return cast(
             RerankResult,
             self._round_robin_invoke(

@@ -295,8 +285,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, ModerationModel):
             raise Exception("Model type instance is not ModerationModel")
-
-        self.model_type_instance = cast(ModerationModel, self.model_type_instance)
         return cast(
             bool,
             self._round_robin_invoke(

@@ -318,8 +306,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, Speech2TextModel):
             raise Exception("Model type instance is not Speech2TextModel")
-
-        self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
         return cast(
             str,
             self._round_robin_invoke(

@@ -343,8 +329,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TTSModel):
             raise Exception("Model type instance is not TTSModel")
-
-        self.model_type_instance = cast(TTSModel, self.model_type_instance)
         return cast(
             Iterable[bytes],
             self._round_robin_invoke(

@@ -404,8 +388,6 @@ class ModelInstance:
         """
         if not isinstance(self.model_type_instance, TTSModel):
             raise Exception("Model type instance is not TTSModel")
-
-        self.model_type_instance = cast(TTSModel, self.model_type_instance)
         return self.model_type_instance.get_tts_model_voices(
             model=self.model, credentials=self.credentials, language=language
         )
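
Note: the removed `self.model_type_instance = cast(...)` assignments were redundant: the preceding isinstance() check already narrows the attribute's type for static checkers inside the guarded branch. A hedged sketch of the idea with illustrative classes:

    class AIModel:
        pass


    class LargeLanguageModel(AIModel):
        def invoke(self) -> str:
            return "ok"


    def invoke_llm(instance: AIModel) -> str:
        if not isinstance(instance, LargeLanguageModel):
            raise TypeError("Model type instance is not LargeLanguageModel")
        # after the isinstance() guard, type checkers already treat `instance`
        # as LargeLanguageModel, so no cast() is needed before calling invoke()
        return instance.invoke()


    assert invoke_llm(LargeLanguageModel()) == "ok"
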
@@ -1,6 +1,7 @@
 from typing import Optional

 from pydantic import BaseModel, Field
+from sqlalchemy import select

 from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor
 from core.helper.encrypter import decrypt_token

@@ -87,10 +88,9 @@ class ApiModeration(Moderation):

     @staticmethod
     def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
-        extension = (
-            db.session.query(APIBasedExtension)
-            .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
-            .first()
+        stmt = select(APIBasedExtension).where(
+            APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id
         )
+        extension = db.session.scalar(stmt)

         return extension
@@ -5,6 +5,7 @@ from typing import Optional
 from urllib.parse import urljoin

 from opentelemetry.trace import Link, Status, StatusCode
+from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker

 from core.ops.aliyun_trace.data_exporter.traceclient import (

@@ -260,15 +261,15 @@ class AliyunDataTrace(BaseTraceInstance):
             app_id = trace_info.metadata.get("app_id")
             if not app_id:
                 raise ValueError("No app_id found in trace_info metadata")

-            app = session.query(App).where(App.id == app_id).first()
+            app_stmt = select(App).where(App.id == app_id)
+            app = session.scalar(app_stmt)
             if not app:
                 raise ValueError(f"App with id {app_id} not found")

             if not app.created_by:
                 raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

-            service_account = session.query(Account).where(Account.id == app.created_by).first()
+            account_stmt = select(Account).where(Account.id == app.created_by)
+            service_account = session.scalar(account_stmt)
             if not service_account:
                 raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
             current_tenant = (
@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+from sqlalchemy import select
 from sqlalchemy.orm import Session

 from core.ops.entities.config_entity import BaseTracingConfig

@@ -44,14 +45,15 @@ class BaseTraceInstance(ABC):
         """
         with Session(db.engine, expire_on_commit=False) as session:
             # Get the app to find its creator
-            app = session.query(App).where(App.id == app_id).first()
+            app_stmt = select(App).where(App.id == app_id)
+            app = session.scalar(app_stmt)
             if not app:
                 raise ValueError(f"App with id {app_id} not found")

             if not app.created_by:
                 raise ValueError(f"App with id {app_id} has no creator (created_by is None)")

-            service_account = session.query(Account).where(Account.id == app.created_by).first()
+            account_stmt = select(Account).where(Account.id == app.created_by)
+            service_account = session.scalar(account_stmt)
             if not service_account:
                 raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
@@ -228,9 +228,9 @@ class OpsTraceManager:

         if not trace_config_data:
             return None

         # decrypt_token
-        app = db.session.query(App).where(App.id == app_id).first()
+        stmt = select(App).where(App.id == app_id)
+        app = db.session.scalar(stmt)
         if not app:
             raise ValueError("App not found")

@@ -297,20 +297,19 @@ class OpsTraceManager:
     @classmethod
     def get_app_config_through_message_id(cls, message_id: str):
         app_model_config = None
-        message_data = db.session.query(Message).where(Message.id == message_id).first()
+        message_stmt = select(Message).where(Message.id == message_id)
+        message_data = db.session.scalar(message_stmt)
         if not message_data:
             return None
         conversation_id = message_data.conversation_id
-        conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
+        conversation_stmt = select(Conversation).where(Conversation.id == conversation_id)
+        conversation_data = db.session.scalar(conversation_stmt)
         if not conversation_data:
             return None

         if conversation_data.app_model_config_id:
-            app_model_config = (
-                db.session.query(AppModelConfig)
-                .where(AppModelConfig.id == conversation_data.app_model_config_id)
-                .first()
-            )
+            config_stmt = select(AppModelConfig).where(AppModelConfig.id == conversation_data.app_model_config_id)
+            app_model_config = db.session.scalar(config_stmt)
         elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
             app_model_config = conversation_data.override_model_configs
@@ -1,6 +1,8 @@
 from collections.abc import Generator, Mapping
 from typing import Optional, Union

+from sqlalchemy import select
+
 from controllers.service_api.wraps import create_or_update_end_user_for_user_id
 from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
 from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator

@@ -191,10 +193,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
         """
         get the user by user id
         """

-        user = db.session.query(EndUser).where(EndUser.id == user_id).first()
+        stmt = select(EndUser).where(EndUser.id == user_id)
+        user = db.session.scalar(stmt)
         if not user:
-            user = db.session.query(Account).where(Account.id == user_id).first()
+            stmt = select(Account).where(Account.id == user_id)
+            user = db.session.scalar(stmt)

         if not user:
             raise ValueError("user not found")
@@ -87,7 +87,6 @@ class PromptMessageUtil:
                 if isinstance(prompt_message.content, list):
                     for content in prompt_message.content:
                         if content.type == PromptMessageContentType.TEXT:
-                            content = cast(TextPromptMessageContent, content)
                             text += content.data
                         else:
                             content = cast(ImagePromptMessageContent, content)
@@ -2,7 +2,7 @@ import contextlib
 import json
 from collections import defaultdict
 from json import JSONDecodeError
-from typing import Any, Optional, cast
+from typing import Any, Optional

 from sqlalchemy import select
 from sqlalchemy.exc import IntegrityError

@@ -154,8 +154,8 @@ class ProviderManager:
         for provider_entity in provider_entities:
             # handle include, exclude
             if is_filtered(
-                include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
-                exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
+                include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
+                exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
                 data=provider_entity,
                 name_func=lambda x: x.provider,
             ):

@@ -276,15 +276,11 @@ class ProviderManager:
         :param model_type: model type
         :return:
         """
         # Get the corresponding TenantDefaultModel record
-        default_model = (
-            db.session.query(TenantDefaultModel)
-            .where(
-                TenantDefaultModel.tenant_id == tenant_id,
-                TenantDefaultModel.model_type == model_type.to_origin_model_type(),
-            )
-            .first()
+        stmt = select(TenantDefaultModel).where(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.to_origin_model_type(),
         )
+        default_model = db.session.scalar(stmt)

         # If it does not exist, get the first available provider model from get_configurations
         # and update the TenantDefaultModel record
@@ -367,16 +363,11 @@ class ProviderManager:
         model_names = [model.model for model in available_models]
         if model not in model_names:
             raise ValueError(f"Model {model} does not exist.")

         # Get the list of available models from get_configurations and check if it is LLM
-        default_model = (
-            db.session.query(TenantDefaultModel)
-            .where(
-                TenantDefaultModel.tenant_id == tenant_id,
-                TenantDefaultModel.model_type == model_type.to_origin_model_type(),
-            )
-            .first()
+        stmt = select(TenantDefaultModel).where(
+            TenantDefaultModel.tenant_id == tenant_id,
+            TenantDefaultModel.model_type == model_type.to_origin_model_type(),
         )
+        default_model = db.session.scalar(stmt)

         # create or update TenantDefaultModel record
         if default_model:

@@ -598,16 +589,13 @@ class ProviderManager:
                 provider_name_to_provider_records_dict[provider_name].append(new_provider_record)
             except IntegrityError:
                 db.session.rollback()
-                existed_provider_record = (
-                    db.session.query(Provider)
-                    .where(
-                        Provider.tenant_id == tenant_id,
-                        Provider.provider_name == ModelProviderID(provider_name).provider_name,
-                        Provider.provider_type == ProviderType.SYSTEM.value,
-                        Provider.quota_type == ProviderQuotaType.TRIAL.value,
-                    )
-                    .first()
+                stmt = select(Provider).where(
+                    Provider.tenant_id == tenant_id,
+                    Provider.provider_name == ModelProviderID(provider_name).provider_name,
+                    Provider.provider_type == ProviderType.SYSTEM.value,
+                    Provider.quota_type == ProviderQuotaType.TRIAL.value,
                 )
+                existed_provider_record = db.session.scalar(stmt)
                 if not existed_provider_record:
                     continue
@@ -3,6 +3,7 @@ from typing import Any, Optional

 import orjson
 from pydantic import BaseModel
+from sqlalchemy import select

 from configs import dify_config
 from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler

@@ -211,11 +212,10 @@ class Jieba(BaseKeyword):
         return sorted_chunk_indices[:k]

     def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
-        document_segment = (
-            db.session.query(DocumentSegment)
-            .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
-            .first()
+        stmt = select(DocumentSegment).where(
+            DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id
         )
+        document_segment = db.session.scalar(stmt)
         if document_segment:
             document_segment.keywords = keywords
             db.session.add(document_segment)
@@ -3,6 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
 from typing import Optional

 from flask import Flask, current_app
+from sqlalchemy import select
 from sqlalchemy.orm import Session, load_only

 from configs import dify_config

@@ -24,7 +25,7 @@ default_retrieval_model = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-    "top_k": 2,
+    "top_k": 4,
     "score_threshold_enabled": False,
 }

@@ -127,7 +128,8 @@ class RetrievalService:
         external_retrieval_model: Optional[dict] = None,
         metadata_filtering_conditions: Optional[dict] = None,
     ):
-        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+        stmt = select(Dataset).where(Dataset.id == dataset_id)
+        dataset = db.session.scalar(stmt)
         if not dataset:
             return []
         metadata_condition = (
@@ -316,10 +318,8 @@ class RetrievalService:
                 if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                     # Handle parent-child documents
                     child_index_node_id = document.metadata.get("doc_id")

-                    child_chunk = (
-                        db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
-                    )
+                    child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id)
+                    child_chunk = db.session.scalar(child_chunk_stmt)

                     if not child_chunk:
                         continue

@@ -378,17 +378,13 @@ class RetrievalService:
                 index_node_id = document.metadata.get("doc_id")
                 if not index_node_id:
                     continue

-                segment = (
-                    db.session.query(DocumentSegment)
-                    .where(
-                        DocumentSegment.dataset_id == dataset_document.dataset_id,
-                        DocumentSegment.enabled == True,
-                        DocumentSegment.status == "completed",
-                        DocumentSegment.index_node_id == index_node_id,
-                    )
-                    .first()
+                document_segment_stmt = select(DocumentSegment).where(
+                    DocumentSegment.dataset_id == dataset_document.dataset_id,
+                    DocumentSegment.enabled == True,
+                    DocumentSegment.status == "completed",
+                    DocumentSegment.index_node_id == index_node_id,
                 )
+                segment = db.session.scalar(document_segment_stmt)

                 if not segment:
                     continue
@@ -256,7 +256,7 @@ class AnalyticdbVectorOpenAPI:
         response = self._client.query_collection_data(request)
         documents = []
         for match in response.body.matches.match:
-            if match.score > score_threshold:
+            if match.score >= score_threshold:
                 metadata = json.loads(match.metadata.get("metadata_"))
                 metadata["score"] = match.score
                 doc = Document(

@@ -293,7 +293,7 @@ class AnalyticdbVectorOpenAPI:
         response = self._client.query_collection_data(request)
         documents = []
         for match in response.body.matches.match:
-            if match.score > score_threshold:
+            if match.score >= score_threshold:
                 metadata = json.loads(match.metadata.get("metadata_"))
                 metadata["score"] = match.score
                 doc = Document(
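
Note: the comparison change from `>` to `>=`, repeated across the vector store integrations below, means results whose score is exactly equal to the configured threshold are now kept rather than dropped. A tiny illustrative sketch:

    def filter_by_score(matches: list[tuple[str, float]], score_threshold: float) -> list[str]:
        # ">=" keeps documents sitting exactly on the threshold,
        # which the previous "> score_threshold" check excluded
        return [doc_id for doc_id, score in matches if score >= score_threshold]


    matches = [("doc-a", 0.80), ("doc-b", 0.79)]
    assert filter_by_score(matches, score_threshold=0.80) == ["doc-a"]
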
@@ -229,7 +229,7 @@ class AnalyticdbVectorBySql:
         documents = []
         for record in cur:
             id, vector, score, page_content, metadata = record
-            if score > score_threshold:
+            if score >= score_threshold:
                 metadata["score"] = score
                 doc = Document(
                     page_content=page_content,

@@ -157,7 +157,7 @@ class BaiduVector(BaseVector):
             if meta is not None:
                 meta = json.loads(meta)
             score = row.get("score", 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 meta["score"] = score
                 doc = Document(page_content=row_data.get(self.field_text), metadata=meta)
                 docs.append(doc)
@@ -120,7 +120,7 @@ class ChromaVector(BaseVector):
             distance = distances[index]
             metadata = dict(metadatas[index])
             score = 1 - distance
-            if score > score_threshold:
+            if score >= score_threshold:
                 metadata["score"] = score
                 doc = Document(
                     page_content=documents[index],

@@ -304,7 +304,7 @@ class CouchbaseVector(BaseVector):
         return docs

     def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
-        top_k = kwargs.get("top_k", 2)
+        top_k = kwargs.get("top_k", 4)
         try:
             CBrequest = search.SearchRequest.create(search.QueryStringQuery("text:" + query))
             search_iter = self._scope.search(
@@ -216,7 +216,7 @@ class ElasticSearchVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

@@ -127,7 +127,7 @@ class HuaweiCloudVector(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)
@@ -275,7 +275,7 @@ class LindormVectorStore(BaseVector):
         docs = []
         for doc, score in docs_and_scores:
             score_threshold = kwargs.get("score_threshold", 0.0) or 0.0
-            if score > score_threshold:
+            if score >= score_threshold:
                 if doc.metadata is not None:
                     doc.metadata["score"] = score
                     docs.append(doc)

@@ -194,7 +194,7 @@ class OpenGauss(BaseVector):
             metadata, text, distance = record
             score = 1 - distance
             metadata["score"] = score
-            if score > score_threshold:
+            if score >= score_threshold:
                 docs.append(Document(page_content=text, metadata=metadata))
         return docs
@@ -211,7 +211,7 @@ class OpenSearchVector(BaseVector):

             metadata["score"] = hit["_score"]
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if hit["_score"] > score_threshold:
+            if hit["_score"] >= score_threshold:
                 doc = Document(page_content=hit["_source"].get(Field.CONTENT_KEY.value), metadata=metadata)
                 docs.append(doc)

@@ -261,7 +261,7 @@ class OracleVector(BaseVector):
             metadata, text, distance = record
             score = 1 - distance
             metadata["score"] = score
-            if score > score_threshold:
+            if score >= score_threshold:
                 docs.append(Document(page_content=text, metadata=metadata))
         conn.close()
         return docs
@@ -202,7 +202,7 @@ class PGVectoRS(BaseVector):
             score = 1 - dis
             metadata["score"] = score
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 doc = Document(page_content=record.text, metadata=metadata)
                 docs.append(doc)
         return docs

@@ -195,7 +195,7 @@ class PGVector(BaseVector):
             metadata, text, distance = record
             score = 1 - distance
             metadata["score"] = score
-            if score > score_threshold:
+            if score >= score_threshold:
                 docs.append(Document(page_content=text, metadata=metadata))
         return docs
@@ -170,7 +170,7 @@ class VastbaseVector(BaseVector):
             metadata, text, distance = record
             score = 1 - distance
             metadata["score"] = score
-            if score > score_threshold:
+            if score >= score_threshold:
                 docs.append(Document(page_content=text, metadata=metadata))
         return docs
@@ -3,7 +3,7 @@ import os
 import uuid
 from collections.abc import Generator, Iterable, Sequence
 from itertools import islice
-from typing import TYPE_CHECKING, Any, Optional, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Union

 import qdrant_client
 from flask import current_app

@@ -18,6 +18,7 @@ from qdrant_client.http.models import (
     TokenizerType,
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
+from sqlalchemy import select

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field

@@ -369,7 +370,7 @@ class QdrantVector(BaseVector):
                 continue
             metadata = result.payload.get(Field.METADATA_KEY.value) or {}
             # duplicate check score threshold
-            if result.score > score_threshold:
+            if result.score >= score_threshold:
                 metadata["score"] = result.score
                 doc = Document(
                     page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

@@ -426,7 +427,6 @@ class QdrantVector(BaseVector):

     def _reload_if_needed(self):
         if isinstance(self._client, QdrantLocal):
-            self._client = cast(QdrantLocal, self._client)
             self._client._load()

     @classmethod

@@ -446,11 +446,8 @@
 class QdrantVectorFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> QdrantVector:
         if dataset.collection_binding_id:
-            dataset_collection_binding = (
-                db.session.query(DatasetCollectionBinding)
-                .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
-                .one_or_none()
-            )
+            stmt = select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == dataset.collection_binding_id)
+            dataset_collection_binding = db.session.scalars(stmt).one_or_none()
             if dataset_collection_binding:
                 collection_name = dataset_collection_binding.collection_name
             else:
@@ -233,7 +233,7 @@ class RelytVector(BaseVector):
         docs = []
         for document, score in results:
             score_threshold = float(kwargs.get("score_threshold") or 0.0)
-            if 1 - score > score_threshold:
+            if 1 - score >= score_threshold:
                 docs.append(document)
         return docs

@@ -300,7 +300,7 @@ class TableStoreVector(BaseVector):
         )
         documents = []
         for search_hit in search_response.search_hits:
-            if search_hit.score > score_threshold:
+            if search_hit.score >= score_threshold:
                 ots_column_map = {}
                 for col in search_hit.row[1]:
                     ots_column_map[col[0]] = col[1]
@@ -39,6 +39,9 @@ class TencentConfig(BaseModel):
         return {"url": self.url, "username": self.username, "key": self.api_key, "timeout": self.timeout}


+bm25 = BM25Encoder.default("zh")
+
+
 class TencentVector(BaseVector):
     field_id: str = "id"
     field_vector: str = "vector"

@@ -53,7 +56,6 @@ class TencentVector(BaseVector):
         self._dimension = 1024
         self._init_database()
         self._load_collection()
-        self._bm25 = BM25Encoder.default("zh")

     def _load_collection(self):
         """

@@ -186,7 +188,7 @@ class TencentVector(BaseVector):
                 metadata=metadata,
             )
             if self._enable_hybrid_search:
-                doc.__dict__["sparse_vector"] = self._bm25.encode_texts(texts[i])
+                doc.__dict__["sparse_vector"] = bm25.encode_texts(texts[i])
             docs.append(doc)
         self._client.upsert(
             database_name=self._client_config.database,

@@ -264,7 +266,7 @@ class TencentVector(BaseVector):
                 match=[
                     KeywordSearch(
                         field_name="sparse_vector",
-                        data=self._bm25.encode_queries(query),
+                        data=bm25.encode_queries(query),
                     ),
                 ],
                 rerank=WeightedRerank(

@@ -291,7 +293,7 @@ class TencentVector(BaseVector):
                 score = 1 - result.get("score", 0.0)
             else:
                 score = result.get("score", 0.0)
-            if score > score_threshold:
+            if score >= score_threshold:
                 meta["score"] = score
                 doc = Document(page_content=result.get(self.field_text), metadata=meta)
                 docs.append(doc)
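
Note: the BM25 encoder moves from a per-instance attribute to a module-level object, so it is built once per process and shared by every TencentVector instance rather than being re-created in each __init__. A hedged sketch of that shape (the encoder class here is a stand-in, not the real BM25Encoder):

    class SparseEncoder:
        """Stand-in for an expensive-to-construct encoder such as a BM25 model."""

        def encode_queries(self, query: str) -> list[float]:
            return [float(len(query))]


    # constructed once at import time and shared by all instances
    _shared_encoder = SparseEncoder()


    class VectorStore:
        def encode_query(self, query: str) -> list[float]:
            # reuse the module-level encoder instead of building one per instance
            return _shared_encoder.encode_queries(query)


    assert VectorStore().encode_query("hello") == [5.0]
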
@@ -20,6 +20,7 @@ from qdrant_client.http.models import (
 )
 from qdrant_client.local.qdrant_local import QdrantLocal
 from requests.auth import HTTPDigestAuth
+from sqlalchemy import select

 from configs import dify_config
 from core.rag.datasource.vdb.field import Field

@@ -351,7 +352,7 @@ class TidbOnQdrantVector(BaseVector):
             metadata = result.payload.get(Field.METADATA_KEY.value) or {}
             # duplicate check score threshold
             score_threshold = kwargs.get("score_threshold") or 0.0
-            if result.score > score_threshold:
+            if result.score >= score_threshold:
                 metadata["score"] = result.score
                 doc = Document(
                     page_content=result.payload.get(Field.CONTENT_KEY.value, ""),

@@ -416,16 +417,12 @@ class TidbOnQdrantVector(BaseVector):

 class TidbOnQdrantVectorFactory(AbstractVectorFactory):
     def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
-        tidb_auth_binding = (
-            db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
-        )
+        stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
+        tidb_auth_binding = db.session.scalars(stmt).one_or_none()
         if not tidb_auth_binding:
             with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
-                tidb_auth_binding = (
-                    db.session.query(TidbAuthBinding)
-                    .where(TidbAuthBinding.tenant_id == dataset.tenant_id)
-                    .one_or_none()
-                )
+                stmt = select(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id)
+                tidb_auth_binding = db.session.scalars(stmt).one_or_none()
                 if tidb_auth_binding:
                     TIDB_ON_QDRANT_API_KEY = f"{tidb_auth_binding.account}:{tidb_auth_binding.password}"
@ -110,7 +110,7 @@ class UpstashVector(BaseVector):
|
|||
score = record.score
|
||||
if metadata is not None and text is not None:
|
||||
metadata["score"] = score
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
docs.append(Document(page_content=text, metadata=metadata))
|
||||
return docs
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import time
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -45,11 +47,10 @@ class Vector:
|
|||
vector_type = self._dataset.index_struct_dict["type"]
|
||||
else:
|
||||
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
|
||||
whitelist = (
|
||||
db.session.query(Whitelist)
|
||||
.where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
|
||||
.one_or_none()
|
||||
stmt = select(Whitelist).where(
|
||||
Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db"
|
||||
)
|
||||
whitelist = db.session.scalars(stmt).one_or_none()
|
||||
if whitelist:
|
||||
vector_type = VectorType.TIDB_ON_QDRANT
|
||||
|
||||
|
|
|
|||
|
|
@ -192,7 +192,7 @@ class VikingDBVector(BaseVector):
|
|||
metadata = result.fields.get(vdb_Field.METADATA_KEY.value)
|
||||
if metadata is not None:
|
||||
metadata = json.loads(metadata)
|
||||
if result.score > score_threshold:
|
||||
if result.score >= score_threshold:
|
||||
metadata["score"] = result.score
|
||||
doc = Document(page_content=result.fields.get(vdb_Field.CONTENT_KEY.value), metadata=metadata)
|
||||
docs.append(doc)
|
||||
|
|
|
|||
|
|
@ -220,7 +220,7 @@ class WeaviateVector(BaseVector):
|
|||
for doc, score in docs_and_scores:
|
||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||
# check score threshold
|
||||
if score > score_threshold:
|
||||
if score >= score_threshold:
|
||||
if doc.metadata is not None:
|
||||
doc.metadata["score"] = score
|
||||
docs.append(doc)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -41,9 +41,8 @@ class DatasetDocumentStore:
|
|||
|
||||
@property
|
||||
def docs(self) -> dict[str, Document]:
|
||||
document_segments = (
|
||||
db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
|
||||
)
|
||||
stmt = select(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id)
|
||||
document_segments = db.session.scalars(stmt).all()
|
||||
|
||||
output = {}
|
||||
for document_segment in document_segments:
|
||||
|
|
@ -228,10 +227,9 @@ class DatasetDocumentStore:
|
|||
return data
|
||||
|
||||
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
|
||||
document_segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
|
||||
.first()
|
||||
stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id
|
||||
)
|
||||
document_segment = db.session.scalar(stmt)
|
||||
|
||||
return document_segment
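Several hunks in this change (here and below) migrate legacy `db.session.query(...).first()/.all()` calls to SQLAlchemy 2.0-style `select()` statements executed with `db.session.scalar()`/`scalars()`. A self-contained sketch of the pattern against an in-memory SQLite database; the `Item` model is illustrative and not part of this diff:

```python
from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):  # illustrative model
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Item(name="a"), Item(name="b")])
    session.commit()

    # 2.0 style: build a Select, then execute it on the session.
    stmt = select(Item).where(Item.name == "a")
    first_match = session.scalar(stmt)               # roughly .first()/.one_or_none()
    all_items = session.scalars(select(Item)).all()  # roughly .all()
```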
@@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings

@@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]
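The comprehension above strips `#` markers from Markdown headers and removes HTML tags from the section text. A tiny runnable sketch of the same cleanup on a hypothetical `(header, text)` pair:

```python
import re

# Hypothetical header/value pair, cleaned the same way as in the comprehension.
key, value = "## Setup", "Run <b>pnpm install</b> first."
clean_key = re.sub(r"#", "", key).strip() if key else None
clean_value = re.sub(r"<.*?>", "", value)
assert clean_key == "Setup"
assert clean_value == "Run pnpm install first."
```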
@@ -4,6 +4,7 @@ import operator
from typing import Any, Optional, cast
import requests
from sqlalchemy import select
from configs import dify_config
from core.rag.extractor.extractor_base import BaseExtractor

@@ -367,22 +368,17 @@ class NotionExtractor(BaseExtractor):
@classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
)
.first()
stmt = select(DataSourceOauthBinding).where(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
)
data_source_binding = db.session.scalar(stmt)
if not data_source_binding:
raise Exception(
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

@@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor

@@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

@@ -123,7 +123,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

@@ -130,13 +130,16 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
if delete_child_chunks:
db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
).delete(synchronize_session=False)
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
# Use existing compound index: (tenant_id, dataset_id, ...)
db.session.query(ChildChunk).where(
ChildChunk.tenant_id == dataset.tenant_id, ChildChunk.dataset_id == dataset.id
).delete(synchronize_session=False)
db.session.commit()
def retrieve(
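The `ParentChildIndexProcessor` hunk above switches the `ChildChunk` cleanup to bulk deletes with `synchronize_session=False` and filters on the `(tenant_id, dataset_id, ...)` compound index. A minimal sketch of that bulk-delete pattern with an illustrative model (not the real `ChildChunk`):

```python
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Chunk(Base):  # illustrative stand-in for ChildChunk
    __tablename__ = "chunks"
    id: Mapped[int] = mapped_column(primary_key=True)
    tenant_id: Mapped[str]
    dataset_id: Mapped[str]


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Chunk(tenant_id="t1", dataset_id="d1"), Chunk(tenant_id="t1", dataset_id="d2")])
    session.commit()

    # Bulk DELETE issued directly against the table; synchronize_session=False
    # skips the in-memory session bookkeeping, which is unnecessary here.
    session.query(Chunk).where(
        Chunk.tenant_id == "t1", Chunk.dataset_id == "d1"
    ).delete(synchronize_session=False)
    session.commit()
```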
@@ -162,7 +165,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs

@@ -158,7 +158,7 @@ class QAIndexProcessor(BaseIndexProcessor):
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
if result.score >= score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import Float, and_, or_, select, text
from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session

@@ -65,7 +65,7 @@ default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"top_k": 4,
"score_threshold_enabled": False,
}
@ -135,7 +135,8 @@ class DatasetRetrieval:
|
|||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
|
|
@ -240,15 +241,12 @@ class DatasetRetrieval:
|
|||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
|
|
@ -327,7 +325,8 @@ class DatasetRetrieval:
|
|||
|
||||
if dataset_id:
|
||||
# get retrieval model config
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
if dataset:
|
||||
results = []
|
||||
if dataset.provider == "external":
|
||||
|
|
@ -514,22 +513,18 @@ class DatasetRetrieval:
|
|||
dify_documents = [document for document in documents if document.provider == "dify"]
|
||||
for document in dify_documents:
|
||||
if document.metadata is not None:
|
||||
dataset_document = (
|
||||
db.session.query(DatasetDocument)
|
||||
.where(DatasetDocument.id == document.metadata["document_id"])
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == document.metadata["document_id"]
|
||||
)
|
||||
dataset_document = db.session.scalar(dataset_document_stmt)
|
||||
if dataset_document:
|
||||
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
|
||||
child_chunk = (
|
||||
db.session.query(ChildChunk)
|
||||
.where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
.first()
|
||||
child_chunk_stmt = select(ChildChunk).where(
|
||||
ChildChunk.index_node_id == document.metadata["doc_id"],
|
||||
ChildChunk.dataset_id == dataset_document.dataset_id,
|
||||
ChildChunk.document_id == dataset_document.id,
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
|
|
@ -600,7 +595,8 @@ class DatasetRetrieval:
|
|||
):
|
||||
with flask_app.app_context():
|
||||
with Session(db.engine) as session:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
|
@ -647,7 +643,7 @@ class DatasetRetrieval:
|
|||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
|
|
@ -685,7 +681,8 @@ class DatasetRetrieval:
|
|||
available_datasets = []
|
||||
for dataset_id in dataset_ids:
|
||||
# get dataset from dataset id
|
||||
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
# pass if dataset is not available
|
||||
if not dataset:
|
||||
|
|
@ -743,7 +740,7 @@ class DatasetRetrieval:
|
|||
tool = DatasetMultiRetrieverTool.from_dataset(
|
||||
dataset_ids=[dataset.id for dataset in available_datasets],
|
||||
tenant_id=tenant_id,
|
||||
top_k=retrieve_config.top_k or 2,
|
||||
top_k=retrieve_config.top_k or 4,
|
||||
score_threshold=retrieve_config.score_threshold,
|
||||
hit_callbacks=[hit_callback],
|
||||
return_resource=return_resource,
|
||||
|
|
@ -958,7 +955,8 @@ class DatasetRetrieval:
|
|||
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(metadata_stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
# get metadata model config
|
||||
if metadata_model_config is None:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from sqlalchemy import select
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
|
|
@ -54,17 +56,13 @@ class ToolLabelManager:
|
|||
return controller.tool_labels
|
||||
else:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
labels = (
|
||||
db.session.query(ToolLabelBinding.label_name)
|
||||
.where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
.all()
|
||||
stmt = select(ToolLabelBinding.label_name).where(
|
||||
ToolLabelBinding.tool_id == provider_id,
|
||||
ToolLabelBinding.tool_type == controller.provider_type.value,
|
||||
)
|
||||
labels = db.session.scalars(stmt).all()
|
||||
|
||||
return [label.label_name for label in labels]
|
||||
return list(labels)
|
||||
|
||||
@classmethod
|
||||
def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
|||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
|
||||
|
|
@ -198,14 +199,11 @@ class ToolManager:
|
|||
# get specific credentials
|
||||
if is_valid_uuid(credential_id):
|
||||
try:
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
.first()
|
||||
builtin_provider_stmt = select(BuiltinToolProvider).where(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.id == credential_id,
|
||||
)
|
||||
builtin_provider = db.session.scalar(builtin_provider_stmt)
|
||||
except Exception as e:
|
||||
builtin_provider = None
|
||||
logger.info("Error getting builtin provider %s:%s", credential_id, e, exc_info=True)
|
||||
|
|
@ -319,11 +317,10 @@ class ToolManager:
|
|||
),
|
||||
)
|
||||
elif provider_type == ToolProviderType.WORKFLOW:
|
||||
workflow_provider = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.first()
|
||||
workflow_provider_stmt = select(WorkflowToolProvider).where(
|
||||
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id
|
||||
)
|
||||
workflow_provider = db.session.scalar(workflow_provider_stmt)
|
||||
|
||||
if workflow_provider is None:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
|
@ -333,16 +330,13 @@ class ToolManager:
|
|||
if controller_tools is None or len(controller_tools) == 0:
|
||||
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
|
||||
|
||||
return cast(
|
||||
WorkflowTool,
|
||||
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
),
|
||||
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials={},
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
)
|
||||
elif provider_type == ToolProviderType.APP:
|
||||
raise NotImplementedError("app provider not implemented")
|
||||
|
|
@ -652,8 +646,8 @@ class ToolManager:
|
|||
for provider in builtin_providers:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
|
||||
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||
data=provider,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Any
|
|||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.model_manager import ModelManager
|
||||
|
|
@ -85,17 +86,14 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
)
|
||||
.all()
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.index_node_id.in_(index_node_ids),
|
||||
)
|
||||
segments = db.session.scalars(document_segment_stmt).all()
|
||||
|
||||
if segments:
|
||||
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
|
||||
|
|
@ -112,15 +110,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
document_stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(document_stmt)
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
position=resource_number,
|
||||
|
|
@ -162,9 +157,8 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
hit_callbacks: list[DatasetIndexToolCallbackHandler],
|
||||
):
|
||||
with flask_app.app_context():
|
||||
dataset = (
|
||||
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
|
||||
)
|
||||
stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id)
|
||||
dataset = db.session.scalar(stmt)
|
||||
|
||||
if not dataset:
|
||||
return []
|
||||
|
|
@ -181,7 +175,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
retrieval_method="keyword_search",
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
)
|
||||
if documents:
|
||||
all_documents.extend(documents)
|
||||
|
|
@ -192,7 +186,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
|||
retrieval_method=retrieval_model["search_method"],
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k") or 2,
|
||||
top_k=retrieval_model.get("top_k") or 4,
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
|
|||
name: str = "dataset"
|
||||
description: str = "use this to retrieve a dataset. "
|
||||
tenant_id: str
|
||||
top_k: int = 2
|
||||
top_k: int = 4
|
||||
score_threshold: Optional[float] = None
|
||||
hit_callbacks: list[DatasetIndexToolCallbackHandler] = []
|
||||
return_resource: bool
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
|
|
@ -56,9 +57,8 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
)
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
dataset = (
|
||||
db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
|
||||
)
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id)
|
||||
dataset = db.session.scalar(dataset_stmt)
|
||||
|
||||
if not dataset:
|
||||
return ""
|
||||
|
|
@ -188,15 +188,12 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
|||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document = (
|
||||
db.session.query(DatasetDocument) # type: ignore
|
||||
.where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
.first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
document = db.session.scalar(dataset_document_stmt) # type: ignore
|
||||
if dataset and document:
|
||||
source = RetrievalSourceMetadata(
|
||||
dataset_id=dataset.id,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from collections.abc import Generator
|
|||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
from mimetypes import guess_extension
|
||||
from typing import Optional, cast
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -159,8 +159,7 @@ class ToolFileMessageTransformer:
|
|||
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
|
||||
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
|
||||
json_msg.json_object = safe_json_value(json_msg.json_object)
|
||||
message.message.json_object = safe_json_value(message.message.json_object)
|
||||
yield message
|
||||
else:
|
||||
yield message
|
||||
|
|
|
|||
|
|
@ -129,17 +129,14 @@ class ModelInvocationUtils:
|
|||
db.session.commit()
|
||||
|
||||
try:
|
||||
response: LLMResult = cast(
|
||||
LLMResult,
|
||||
model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
),
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||
from core.tools.__base.tool import Tool
|
||||
|
|
@ -133,7 +135,8 @@ class WorkflowTool(Tool):
|
|||
.first()
|
||||
)
|
||||
else:
|
||||
workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
|
||||
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
|
||||
workflow = db.session.scalar(stmt)
|
||||
|
||||
if not workflow:
|
||||
raise ValueError("workflow not found or not published")
|
||||
|
|
@ -144,7 +147,8 @@ class WorkflowTool(Tool):
|
|||
"""
|
||||
get the app by app id
|
||||
"""
|
||||
app = db.session.query(App).where(App.id == app_id).first()
|
||||
stmt = select(App).where(App.id == app_id)
|
||||
app = db.session.scalar(stmt)
|
||||
if not app:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
|
@ -201,14 +205,14 @@ class WorkflowTool(Tool):
|
|||
item = self._update_file_mapping(item)
|
||||
file = build_from_mapping(
|
||||
mapping=item,
|
||||
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
)
|
||||
files.append(file)
|
||||
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
value = self._update_file_mapping(value)
|
||||
file = build_from_mapping(
|
||||
mapping=value,
|
||||
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
|
||||
tenant_id=str(self.runtime.tenant_id),
|
||||
)
|
||||
files.append(file)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Sequence
|
||||
from typing import Annotated, TypeAlias, cast
|
||||
from typing import Annotated, TypeAlias
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import Discriminator, Field, Tag
|
||||
|
|
@ -86,7 +86,7 @@ class SecretVariable(StringVariable):
|
|||
|
||||
@property
|
||||
def log(self) -> str:
|
||||
return cast(str, encrypter.obfuscated_token(self.value))
|
||||
return encrypter.obfuscated_token(self.value)
|
||||
|
||||
|
||||
class NoneVariable(NoneSegment, Variable):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,339 @@
|
|||
"""
|
||||
QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution.
|
||||
|
||||
This engine uses a modular architecture with separated packages following
|
||||
Domain-Driven Design principles for improved maintainability and testability.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import final
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphRuntimeState
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphNodeEventBase,
|
||||
GraphRunAbortedEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
)
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .command_processing import AbortCommandHandler, CommandProcessor
|
||||
from .domain import ExecutionContext, GraphExecution
|
||||
from .entities.commands import AbortCommand
|
||||
from .error_handling import ErrorHandler
|
||||
from .event_management import EventHandler, EventManager
|
||||
from .graph_traversal import EdgeProcessor, SkipPropagator
|
||||
from .layers.base import Layer
|
||||
from .orchestration import Dispatcher, ExecutionCoordinator
|
||||
from .protocols.command_channel import CommandChannel
|
||||
from .response_coordinator import ResponseStreamCoordinator
|
||||
from .state_management import UnifiedStateManager
|
||||
from .worker_management import SimpleWorkerPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
Queue-based graph execution engine.
|
||||
|
||||
Uses a modular architecture that delegates responsibilities to specialized
|
||||
subsystems, following Domain-Driven Design and SOLID principles.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
graph_config: Mapping[str, object],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
command_channel: CommandChannel,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
|
||||
# === Domain Models ===
|
||||
# Execution context encapsulates workflow execution metadata
|
||||
self._execution_context = ExecutionContext(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
max_execution_steps=max_execution_steps,
|
||||
max_execution_time=max_execution_time,
|
||||
)
|
||||
|
||||
# Graph execution tracks the overall execution state
|
||||
self._graph_execution = GraphExecution(workflow_id=workflow_id)
|
||||
|
||||
# === Core Dependencies ===
|
||||
# Graph structure and configuration
|
||||
self._graph = graph
|
||||
self._graph_config = graph_config
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._command_channel = command_channel
|
||||
|
||||
# === Worker Management Parameters ===
|
||||
# Parameters for dynamic worker pool scaling
|
||||
self._min_workers = min_workers
|
||||
self._max_workers = max_workers
|
||||
self._scale_up_threshold = scale_up_threshold
|
||||
self._scale_down_idle_time = scale_down_idle_time
|
||||
|
||||
# === Execution Queues ===
|
||||
# Queue for nodes ready to execute
|
||||
self._ready_queue: queue.Queue[str] = queue.Queue()
|
||||
# Queue for events generated during execution
|
||||
self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()
|
||||
|
||||
# === State Management ===
|
||||
# Unified state manager handles all node state transitions and queue operations
|
||||
self._state_manager = UnifiedStateManager(self._graph, self._ready_queue)
|
||||
|
||||
# === Response Coordination ===
|
||||
# Coordinates response streaming from response nodes
|
||||
self._response_coordinator = ResponseStreamCoordinator(
|
||||
variable_pool=self._graph_runtime_state.variable_pool, graph=self._graph
|
||||
)
|
||||
|
||||
# === Event Management ===
|
||||
# Event manager handles both collection and emission of events
|
||||
self._event_manager = EventManager()
|
||||
|
||||
# === Error Handling ===
|
||||
# Centralized error handler for graph execution errors
|
||||
self._error_handler = ErrorHandler(self._graph, self._graph_execution)
|
||||
|
||||
# === Graph Traversal Components ===
|
||||
# Propagates skip status through the graph when conditions aren't met
|
||||
self._skip_propagator = SkipPropagator(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
)
|
||||
|
||||
# Processes edges to determine next nodes after execution
|
||||
# Also handles conditional branching and route selection
|
||||
self._edge_processor = EdgeProcessor(
|
||||
graph=self._graph,
|
||||
state_manager=self._state_manager,
|
||||
response_coordinator=self._response_coordinator,
|
||||
skip_propagator=self._skip_propagator,
|
||||
)
|
||||
|
||||
# === Event Handler Registry ===
|
||||
# Central registry for handling all node execution events
|
||||
self._event_handler_registry = EventHandler(
|
||||
graph=self._graph,
|
||||
graph_runtime_state=self._graph_runtime_state,
|
||||
graph_execution=self._graph_execution,
|
||||
response_coordinator=self._response_coordinator,
|
||||
event_collector=self._event_manager,
|
||||
edge_processor=self._edge_processor,
|
||||
state_manager=self._state_manager,
|
||||
error_handler=self._error_handler,
|
||||
)
|
||||
|
||||
# === Command Processing ===
|
||||
# Processes external commands (e.g., abort requests)
|
||||
self._command_processor = CommandProcessor(
|
||||
command_channel=self._command_channel,
|
||||
graph_execution=self._graph_execution,
|
||||
)
|
||||
|
||||
# Register abort command handler
|
||||
abort_handler = AbortCommandHandler()
|
||||
self._command_processor.register_handler(
|
||||
AbortCommand,
|
||||
abort_handler,
|
||||
)
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
if isinstance(app, Flask):
|
||||
flask_app = app
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Capture context variables for worker threads
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = SimpleWorkerPool(
|
||||
ready_queue=self._ready_queue,
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
min_workers=self._min_workers,
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
# Coordinates the overall execution lifecycle
|
||||
self._execution_coordinator = ExecutionCoordinator(
|
||||
graph_execution=self._graph_execution,
|
||||
state_manager=self._state_manager,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
command_processor=self._command_processor,
|
||||
worker_pool=self._worker_pool,
|
||||
)
|
||||
|
||||
# Dispatches events and manages execution flow
|
||||
self._dispatcher = Dispatcher(
|
||||
event_queue=self._event_queue,
|
||||
event_handler=self._event_handler_registry,
|
||||
event_collector=self._event_manager,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
max_execution_time=self._execution_context.max_execution_time,
|
||||
event_emitter=self._event_manager,
|
||||
)
|
||||
|
||||
# === Extensibility ===
|
||||
# Layers allow plugins to extend engine functionality
|
||||
self._layers: list[Layer] = []
|
||||
|
||||
# === Validation ===
|
||||
# Ensure all nodes share the same GraphRuntimeState instance
|
||||
self._validate_graph_state_consistency()
|
||||
|
||||
def _validate_graph_state_consistency(self) -> None:
|
||||
"""Validate that all nodes share the same GraphRuntimeState."""
|
||||
expected_state_id = id(self._graph_runtime_state)
|
||||
for node in self._graph.nodes.values():
|
||||
if id(node.graph_runtime_state) != expected_state_id:
|
||||
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")
|
||||
|
||||
def layer(self, layer: Layer) -> "GraphEngine":
|
||||
"""Add a layer for extending functionality."""
|
||||
self._layers.append(layer)
|
||||
return self
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Execute the graph using the modular architecture.
|
||||
|
||||
Returns:
|
||||
Generator yielding GraphEngineEvent instances
|
||||
"""
|
||||
try:
|
||||
# Initialize layers
|
||||
self._initialize_layers()
|
||||
|
||||
# Start execution
|
||||
self._graph_execution.start()
|
||||
start_event = GraphRunStartedEvent()
|
||||
yield start_event
|
||||
|
||||
# Start subsystems
|
||||
self._start_execution()
|
||||
|
||||
# Yield events as they occur
|
||||
yield from self._event_manager.emit_events()
|
||||
|
||||
# Handle completion
|
||||
if self._graph_execution.aborted:
|
||||
abort_reason = "Workflow execution aborted by user command"
|
||||
if self._graph_execution.error:
|
||||
abort_reason = str(self._graph_execution.error)
|
||||
yield GraphRunAbortedEvent(
|
||||
reason=abort_reason,
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
elif self._graph_execution.has_error:
|
||||
if self._graph_execution.error:
|
||||
raise self._graph_execution.error
|
||||
else:
|
||||
yield GraphRunSucceededEvent(
|
||||
outputs=self._graph_runtime_state.outputs,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
raise
|
||||
|
||||
finally:
|
||||
self._stop_execution()
|
||||
|
||||
def _initialize_layers(self) -> None:
|
||||
"""Initialize layers with context."""
|
||||
self._event_manager.set_layers(self._layers)
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.initialize(self._graph_runtime_state, self._command_channel)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)
|
||||
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
|
||||
def _start_execution(self) -> None:
|
||||
"""Start execution subsystems."""
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
|
||||
# Register response nodes
|
||||
for node in self._graph.nodes.values():
|
||||
if node.execution_type == NodeExecutionType.RESPONSE:
|
||||
self._response_coordinator.register(node.id)
|
||||
|
||||
# Enqueue root node
|
||||
root_node = self._graph.root_node
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
|
||||
# Start dispatcher
|
||||
self._dispatcher.start()
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
def graph_runtime_state(self) -> GraphRuntimeState:
|
||||
"""Get the graph runtime state."""
|
||||
return self._graph_runtime_state
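The new `GraphEngine` above exposes `run()` as a generator of engine events (a start event, streamed node events, then a succeeded/failed/aborted event). A hedged, runnable stand-in that mirrors only that generator contract; the event classes here are simplified placeholders, not the real `graph_events` types, and construction of the actual engine (graph, runtime state, command channel, workers) is not shown:

```python
from dataclasses import dataclass, field
from typing import Any, Generator


@dataclass
class RunStarted:
    pass


@dataclass
class RunSucceeded:
    outputs: dict[str, Any] = field(default_factory=dict)


def run() -> Generator[object, None, None]:
    # The real GraphEngine.run() yields a start event, streams node-level
    # events, and finishes with a succeeded/failed/aborted event.
    yield RunStarted()
    yield RunSucceeded(outputs={"answer": "ok"})


for event in run():
    if isinstance(event, RunSucceeded):
        print("outputs:", event.outputs)
```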
@ -157,7 +157,7 @@ class AgentNode(Node):
|
|||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
|
||||
"agent_strategy": self._node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
|
|
@ -401,8 +401,7 @@ class AgentNode(Node):
|
|||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}"
|
||||
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
|
|
|
|||
|
|
@ -301,12 +301,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
|
|||
encoding = "utf-8"
|
||||
|
||||
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
|
||||
# If decoding fails, try with utf-8 as last resort
|
||||
try:
|
||||
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
|
||||
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
|
||||
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
|
||||
except (UnicodeDecodeError, yaml.YAMLError):
|
||||
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from collections import defaultdict
|
|||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from sqlalchemy import Float, and_, func, or_, text
|
||||
from sqlalchemy import Float, and_, func, or_, select, text
|
||||
from sqlalchemy import cast as sqlalchemy_cast
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ default_retrieval_model = {
|
|||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
|
@ -358,15 +358,12 @@ class KnowledgeRetrievalNode(Node):
|
|||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
.first()
|
||||
stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
)
|
||||
document = db.session.scalar(stmt)
|
||||
if dataset and document:
|
||||
source = {
|
||||
"metadata": {
|
||||
|
|
@ -505,7 +502,8 @@ class KnowledgeRetrievalNode(Node):
|
|||
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
|
||||
) -> list[dict[str, Any]]:
|
||||
# get all metadata field
|
||||
metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
|
||||
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
|
||||
metadata_fields = db.session.scalars(stmt).all()
|
||||
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
|
||||
if node_data.metadata_model_config is None:
|
||||
raise ValueError("metadata_model_config is required")
|
||||
|
|
|
|||
|
|
@ -138,7 +138,7 @@ class ParameterExtractorNode(Node):
|
|||
"""
|
||||
Run the node.
|
||||
"""
|
||||
node_data = cast(ParameterExtractorNodeData, self._node_data)
|
||||
node_data = self._node_data
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||
query = variable.text if variable else ""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -103,7 +103,7 @@ class QuestionClassifierNode(Node):
|
|||
return "1"
|
||||
|
||||
def _run(self):
|
||||
node_data = cast(QuestionClassifierNodeData, self._node_data)
|
||||
node_data = self._node_data
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# extract variables
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -61,7 +61,7 @@ class ToolNode(Node):
|
|||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
node_data = cast(ToolNodeData, self._node_data)
|
||||
node_data = self._node_data
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,493 @@
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.errors import ToolInvokeError
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.enums import (
|
||||
ErrorStrategy,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from models import ToolFile
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .entities import ToolNodeData
|
||||
from .exc import (
|
||||
ToolFileError,
|
||||
ToolNodeError,
|
||||
ToolParameterError,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import VariablePool
|
||||
|
||||
|
||||
class ToolNode(Node):
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
_node_data: ToolNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = ToolNodeData.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
|
||||
|
||||
node_data = self._node_data
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
"provider_type": node_data.provider_type.value,
|
||||
"provider_id": node_data.provider_id,
|
||||
"plugin_unique_identifier": node_data.plugin_unique_identifier,
|
||||
}
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version != "1":
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self._node_id, self._node_data, self.invoke_from, variable_pool
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_merged_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self._node_data,
|
||||
for_log=True,
|
||||
)
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = ToolEngine.generic_invoke(
|
||||
tool=tool_runtime,
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
)
|
||||
except ToolNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# convert tool messages
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info=tool_info,
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_id=self._node_id,
|
||||
)
|
||||
except ToolInvokeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool {node_data.provider_name}: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except PluginInvokeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error="An error occurred in the plugin, "
|
||||
f"please contact the author of {node_data.provider_name} for help, "
|
||||
f"error type: {e.get_error_type()}, "
|
||||
f"error details: {e.get_error_message()}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool, error: {e.description}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
self,
|
||||
*,
|
||||
tool_parameters: Sequence[ToolParameter],
|
||||
variable_pool: "VariablePool",
|
||||
node_data: ToolNodeData,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
parameter = tool_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
result[parameter_name] = parameter_value
|
||||
|
||||
return result
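`_generate_parameters` above resolves each declared tool parameter either from the variable pool ("variable" inputs) or from a rendered template/constant ("mixed"/"constant" inputs). A runnable sketch of that resolution logic using plain dicts as stand-ins for the real `VariablePool` and node data:

```python
from typing import Any

variable_pool = {("node1", "text"): "hello"}  # selector -> value (stand-in for VariablePool)

tool_inputs = {
    "query": {"type": "variable", "value": ("node1", "text")},
    "limit": {"type": "constant", "value": "10"},
}

result: dict[str, Any] = {}
for name, tool_input in tool_inputs.items():
    if tool_input["type"] == "variable":
        # look the selector up in the pool
        result[name] = variable_pool.get(tool_input["value"])
    elif tool_input["type"] in {"mixed", "constant"}:
        # the real code renders a template; a constant reduces to its text
        result[name] = str(tool_input["value"])
    else:
        raise ValueError(f"Unknown tool input type '{tool_input['type']}'")

assert result == {"query": "hello", "limit": "10"}
```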
def _fetch_files(self, variable_pool: "VariablePool") -> list[File]:
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_id: str,
|
||||
) -> Generator:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json: list[dict] = []
|
||||
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"Tool file {tool_file_id} does not exist")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileError(f"tool file {tool_file_id} not exists")
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
# JSON message handling for tool node
|
||||
if message.message.json_object is not None:
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise ToolNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any]] = []
|
||||
|
||||
# Normalize JSON output into a list[dict]; fall back to a single {"data": []} entry when no JSON messages were produced
|
||||
if json:
|
||||
json_output.extend(json)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
)
|
||||
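# The tool file id above is recovered from the tail of the returned URL
# (split("/")[-1].split(".")[0]). A tiny, hypothetical example of that parsing (the URL is invented):
_url = "https://files.example.com/tools/3f2b9c1e-1111-2222-3333-444455556666.png"
_tool_file_id = str(_url).split("/")[-1].split(".")[0]
print(_tool_file_id)  # 3f2b9c1e-1111-2222-3333-444455556666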
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
# Create typed NodeData from dict
|
||||
typed_node_data = ToolNodeData.model_validate(node_data)
|
||||
|
||||
result = {}
|
||||
for parameter_name in typed_node_data.tool_parameters:
|
||||
input = typed_node_data.tool_parameters[parameter_name]
|
||||
if input.type == "mixed":
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
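# Hypothetical illustration of the mapping shape produced above (node id, parameter names and
# selectors are invented; template keys follow VariableTemplateParser). For a node "tool_1"
# configured with
#   query  -> {"type": "variable", "value": ["start", "query"]}
#   prompt -> {"type": "mixed",    "value": "Summarize: {{#start.query#}}"}
# the method would return:
_expected_mapping = {
    "tool_1.query": ["start", "query"],          # "variable" input: keyed by parameter name
    "tool_1.#start.query#": ["start", "query"],  # "mixed" input: keyed by the extracted template variable
}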
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@property
|
||||
def retry(self) -> bool:
|
||||
return self._node_data.retry_config.retry_enabled
|
||||
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
|
|
@ -318,7 +318,6 @@ class WorkflowEntry:
|
|||
# init graph
|
||||
graph = Graph.init(graph_config=graph_dict, node_factory=node_factory)
|
||||
|
||||
node_cls = cast(type[Node], node_cls)
|
||||
# init workflow run state
|
||||
node_config = {
|
||||
"id": node_id,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,445 @@
|
|||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer
|
||||
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
|
||||
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEntry:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph: Graph,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: Optional[CommandChannel] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Init workflow entry
|
||||
:param tenant_id: tenant id
|
||||
:param app_id: app id
|
||||
:param workflow_id: workflow id
|
||||
:param graph_config: workflow graph config
|
||||
:param graph: workflow graph
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param invoke_from: invoke from
|
||||
:param call_depth: call depth
|
||||
:param graph_runtime_state: pre-created graph runtime state
|
||||
:param command_channel: command channel for external control (optional, defaults to InMemoryChannel)
|
||||
"""
|
||||
# check call depth
|
||||
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
|
||||
if call_depth > workflow_call_max_depth:
|
||||
raise ValueError(f"Max workflow call depth {workflow_call_max_depth} reached.")
|
||||
|
||||
# Use provided command channel or default to InMemoryChannel
|
||||
if command_channel is None:
|
||||
command_channel = InMemoryChannel()
|
||||
|
||||
self.command_channel = command_channel
|
||||
self.graph_engine = GraphEngine(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
# Add debug logging layer when in debug mode
|
||||
if dify_config.DEBUG:
|
||||
logger.info("Debug mode enabled - adding DebugLoggingLayer to GraphEngine")
|
||||
debug_layer = DebugLoggingLayer(
|
||||
level="DEBUG",
|
||||
include_inputs=True,
|
||||
include_outputs=True,
|
||||
include_process_data=False, # Process data can be very verbose
|
||||
logger_name=f"GraphEngine.Debug.{workflow_id[:8]}", # Use workflow ID prefix for unique logger
|
||||
)
|
||||
self.graph_engine.layer(debug_layer)
|
||||
|
||||
# Add execution limits layer
|
||||
limits_layer = ExecutionLimitsLayer(
|
||||
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
self.graph_engine.layer(limits_layer)
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
graph_engine = self.graph_engine
|
||||
|
||||
try:
|
||||
# run workflow
|
||||
generator = graph_engine.run()
|
||||
yield from generator
|
||||
except GenerateTaskStoppedError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when workflow entry running")
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
return
|
||||
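# A minimal consumption sketch for the generator returned by run(); `entry` is assumed to be an
# already-constructed WorkflowEntry, and only event types imported above are referenced.
def _drain_events(events):
    for event in events:
        if isinstance(event, GraphRunFailedEvent):
            print("workflow failed:", event.error)
            break
        # other GraphEngineEvent subclasses would be dispatched here
# _drain_events(entry.run())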
|
||||
@classmethod
|
||||
def single_step_run(
|
||||
cls,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
:param workflow: Workflow instance
|
||||
:param node_id: node id
|
||||
:param user_id: user id
|
||||
:param user_inputs: user inputs
|
||||
:param variable_pool: variable pool
:param variable_loader: loader for draft variables missing from the pool
:return: the instantiated node and a generator of its run events
|
||||
"""
|
||||
node_config = workflow.get_node_config_by_id(node_id)
|
||||
node_config_data = node_config.get("data", {})
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType(node_config_data.get("type"))
|
||||
node_version = node_config_data.get("version", "1")
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=workflow.graph_dict, node_factory=node_factory)
|
||||
|
||||
# init workflow run state
|
||||
node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(node_config_data)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
# Loading missing variable from draft var here, and set it into
|
||||
# variable_pool.
|
||||
load_into_variable_pool(
|
||||
variable_loader=variable_loader,
|
||||
variable_pool=variable_pool,
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
)
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node.run()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node, workflow_id=%s, node_id=%s, node_type=%s, node_version=%s",
|
||||
workflow.id,
|
||||
node.id,
|
||||
node.node_type,
|
||||
node.version(),
|
||||
)
|
||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
return node, generator
|
||||
|
||||
@staticmethod
|
||||
def _create_single_node_graph(
|
||||
node_id: str,
|
||||
node_data: dict[str, Any],
|
||||
node_width: int = 114,
|
||||
node_height: int = 514,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Create a minimal graph structure for testing a single node in isolation.
|
||||
|
||||
:param node_id: ID of the target node
|
||||
:param node_data: configuration data for the target node
|
||||
:param node_width: width for UI layout (default: 114)
|
||||
:param node_height: height for UI layout (default: 514)
|
||||
:return: graph dictionary with start node and target node
|
||||
"""
|
||||
node_config = {
|
||||
"id": node_id,
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": node_data,
|
||||
}
|
||||
start_node_config = {
|
||||
"id": "start",
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": NodeType.START.value,
|
||||
"title": "Start",
|
||||
"desc": "Start",
|
||||
},
|
||||
}
|
||||
return {
|
||||
"nodes": [start_node_config, node_config],
|
||||
"edges": [
|
||||
{
|
||||
"source": "start",
|
||||
"target": node_id,
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Run free node
|
||||
|
||||
NOTE: only parameter_extractor/question_classifier are supported
|
||||
|
||||
:param node_data: node data
|
||||
:param node_id: node id
|
||||
:param tenant_id: tenant id
|
||||
:param user_id: user id
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
# Create a minimal graph for single node execution
|
||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||
|
||||
node_type = NodeType(node_data.get("type", ""))
|
||||
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"]
|
||||
if not node_cls:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.empty(),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
)
|
||||
|
||||
# init graph init params and runtime state
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id="",
|
||||
workflow_id="",
|
||||
graph_config=graph_dict,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=graph_dict, node_factory=node_factory)
|
||||
|
||||
node_cls = cast(type[Node], node_cls)
|
||||
# init workflow run state
|
||||
node_config = {
|
||||
"id": node_id,
|
||||
"data": node_data,
|
||||
}
|
||||
node: Node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_dict, config=node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# run node
|
||||
generator = node.run()
|
||||
|
||||
return node, generator
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"error while running node, node_id=%s, node_type=%s, node_version=%s",
|
||||
node.id,
|
||||
node.node_type,
|
||||
node.version(),
|
||||
)
|
||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: Any) -> Any:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = WorkflowEntry._handle_special_values(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||
return res_list
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
return value
|
||||
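# Small stand-alone illustration of the recursive walk above; FakeFile only mimics the
# to_dict() behaviour of core.file.models.File and is not the real class.
class FakeFile:
    def to_dict(self):
        return {"name": "report.pdf", "transfer_method": "tool_file"}

def _walk(value):
    if isinstance(value, dict):
        return {k: _walk(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_walk(v) for v in value]
    if isinstance(value, FakeFile):
        return value.to_dict()
    return value

print(_walk({"outputs": [FakeFile(), {"n": 1}]}))
# {'outputs': [{'name': 'report.pdf', 'transfer_method': 'tool_file'}, {'n': 1}]}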
|
||||
@classmethod
|
||||
def mapping_user_inputs_to_variable_pool(
|
||||
cls,
|
||||
*,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
# NOTE(QuantumGhost): This logic should remain synchronized with
|
||||
# the implementation of `load_into_variable_pool`, specifically the logic about
|
||||
# variable existence checking.
|
||||
|
||||
# WARNING(QuantumGhost): The semantics of this method are not clearly defined,
|
||||
# and multiple parts of the codebase depend on its current behavior.
|
||||
# Modify with caution.
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
node_variable_list = node_variable.split(".")
|
||||
if len(node_variable_list) < 1:
|
||||
raise ValueError(f"Invalid node variable {node_variable}")
|
||||
|
||||
node_variable_key = ".".join(node_variable_list[1:])
|
||||
|
||||
if (node_variable_key not in user_inputs and node_variable not in user_inputs) and not variable_pool.get(
|
||||
variable_selector
|
||||
):
|
||||
raise ValueError(f"Variable key {node_variable} not found in user inputs.")
|
||||
|
||||
# environment variable already exist in variable pool, not from user inputs
|
||||
if variable_pool.get(variable_selector):
|
||||
continue
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
variable_key_list = list(variable_key_list)
|
||||
|
||||
# get input value
|
||||
input_value = user_inputs.get(node_variable)
|
||||
if not input_value:
|
||||
input_value = user_inputs.get(node_variable_key)
|
||||
|
||||
if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
|
||||
input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
|
||||
if (
|
||||
isinstance(input_value, list)
|
||||
and all(isinstance(item, dict) for item in input_value)
|
||||
and all("type" in item and "transfer_method" in item for item in input_value)
|
||||
):
|
||||
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
|
||||
|
||||
# append variable and value to variable pool
|
||||
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
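# Illustration of the expected shapes (all names invented): a mapping key is
# "<node_id>.<variable_key>", user inputs may use either the full key or the short variable key,
# and matched values end up in the pool under the selector.
_variable_mapping = {"llm_1.#start.query#": ["start", "query"]}
_user_inputs = {"#start.query#": "hello"}  # the short key form is also accepted
# after mapping_user_inputs_to_variable_pool(...), the pool would answer:
# variable_pool.get(["start", "query"]) == "hello"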
|
|
@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, Ch
|
|||
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
from libs import datetime_utils
|
||||
from models.model import Message
|
||||
from models.provider import Provider, ProviderType
|
||||
|
|
@ -19,6 +20,32 @@ from models.provider_ids import ModelProviderID
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis cache key prefix for provider last used timestamps
|
||||
_PROVIDER_LAST_USED_CACHE_PREFIX = "provider:last_used"
|
||||
# Default TTL for cache entries (10 minutes)
|
||||
_CACHE_TTL_SECONDS = 600
|
||||
LAST_USED_UPDATE_WINDOW_SECONDS = 60 * 5
|
||||
|
||||
|
||||
def _get_provider_cache_key(tenant_id: str, provider_name: str) -> str:
|
||||
"""Generate Redis cache key for provider last used timestamp."""
|
||||
return f"{_PROVIDER_LAST_USED_CACHE_PREFIX}:{tenant_id}:{provider_name}"
|
||||
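# e.g. _get_provider_cache_key("tenant-123", "openai") -> "provider:last_used:tenant-123:openai"
# (tenant id and provider name here are invented examples)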
|
||||
|
||||
@redis_fallback(default_return=None)
|
||||
def _get_last_update_timestamp(cache_key: str) -> Optional[datetime]:
|
||||
"""Get last update timestamp from Redis cache."""
|
||||
timestamp_str = redis_client.get(cache_key)
|
||||
if timestamp_str:
|
||||
return datetime.fromtimestamp(float(timestamp_str.decode("utf-8")))
|
||||
return None
|
||||
|
||||
|
||||
@redis_fallback()
|
||||
def _set_last_update_timestamp(cache_key: str, timestamp: datetime) -> None:
|
||||
"""Set last update timestamp in Redis cache with TTL."""
|
||||
redis_client.setex(cache_key, _CACHE_TTL_SECONDS, str(timestamp.timestamp()))
|
||||
|
||||
|
||||
class _ProviderUpdateFilters(BaseModel):
|
||||
"""Filters for identifying Provider records to update."""
|
||||
|
|
@ -139,7 +166,7 @@ def handle(sender: Message, **kwargs):
|
|||
provider_name,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Log failure with timing and context
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
|
|
@ -215,8 +242,23 @@ def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]
|
|||
|
||||
# Prepare values dict for SQLAlchemy update
|
||||
update_values = {}
|
||||
# Unconditional updates to `last_used` were removed for performance reasons.
|
||||
# ref: https://github.com/langgenius/dify/issues/24526
|
||||
|
||||
# NOTE: For frequently used providers under high load, this implementation may experience
|
||||
# race conditions or update contention despite the time-window optimization:
|
||||
# 1. Multiple concurrent requests might check the same cache key simultaneously
|
||||
# 2. Redis cache operations are not atomic with database updates
|
||||
# 3. Heavy providers could still face database lock contention during peak usage
|
||||
# The current implementation is acceptable for most scenarios, but future optimization
|
||||
# considerations could include batched updates, async processing, or an atomic window claim (see the sketch after this hunk).
|
||||
if values.last_used is not None:
|
||||
cache_key = _get_provider_cache_key(filters.tenant_id, filters.provider_name)
|
||||
now = datetime_utils.naive_utc_now()
|
||||
last_update = _get_last_update_timestamp(cache_key)
|
||||
|
||||
if last_update is None or (now - last_update).total_seconds() > LAST_USED_UPDATE_WINDOW_SECONDS:
|
||||
update_values["last_used"] = values.last_used
|
||||
_set_last_update_timestamp(cache_key, now)
|
||||
|
||||
if values.quota_used is not None:
|
||||
update_values["quota_used"] = values.quota_used
|
||||
# Skip the current update operation if no updates are required.
|
||||
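# Hedged sketch of a tighter variant than the code above (not what this handler does): claim the
# window with a single atomic SET NX EX, so only the first concurrent caller in each window
# performs the database update. `redis_client` is the same client used above.
def _try_claim_update_window(cache_key, window_seconds):
    # SET key value NX EX window -> succeeds only if the key does not exist yet
    return bool(redis_client.set(cache_key, "1", nx=True, ex=window_seconds))

# if _try_claim_update_window(cache_key, LAST_USED_UPDATE_WINDOW_SECONDS):
#     update_values["last_used"] = values.last_used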
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import os
|
|||
import urllib.parse
|
||||
import uuid
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from sqlalchemy import select
|
||||
|
|
@ -258,7 +258,6 @@ def _get_remote_file_info(url: str):
|
|||
mime_type = ""
|
||||
|
||||
resp = ssrf_proxy.head(url, follow_redirects=True)
|
||||
resp = cast(httpx.Response, resp)
|
||||
if resp.status_code == httpx.codes.OK:
|
||||
if content_disposition := resp.headers.get("Content-Disposition"):
|
||||
filename = str(content_disposition.split("filename=")[-1].strip('"'))
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class EnvironmentVariableField(fields.Raw):
|
|||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
"value": encrypter.obfuscated_token(value.value),
|
||||
"value": encrypter.full_mask_token(),
|
||||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
"""chore: add workflow app log run id index
|
||||
|
||||
Revision ID: b95962a3885c
|
||||
Revises: 8d289573e1da
|
||||
Create Date: 2025-08-29 15:34:09.838623
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b95962a3885c'
|
||||
down_revision = '8d289573e1da'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_app_log_workflow_run_id_idx', ['workflow_run_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_app_logs', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_app_log_workflow_run_id_idx')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -319,7 +319,7 @@ class MCPToolProvider(Base):
|
|||
|
||||
@property
|
||||
def decrypted_server_url(self) -> str:
|
||||
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
|
||||
return encrypter.decrypt_token(self.tenant_id, self.server_url)
|
||||
|
||||
@property
|
||||
def masked_server_url(self) -> str:
|
||||
|
|
|
|||
|
|
@ -835,6 +835,7 @@ class WorkflowAppLog(Base):
|
|||
__table_args__ = (
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
|
||||
sa.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
|
||||
sa.Index("workflow_app_log_workflow_run_id_idx", "workflow_run_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
|
||||
|
|
|
|||
|
|
@ -146,7 +146,7 @@ class AccountService:
|
|||
account.last_active_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
return cast(Account, account)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def get_account_jwt_token(account: Account) -> str:
|
||||
|
|
@ -191,7 +191,7 @@ class AccountService:
|
|||
|
||||
db.session.commit()
|
||||
|
||||
return cast(Account, account)
|
||||
return account
|
||||
|
||||
@staticmethod
|
||||
def update_account_password(account, password, new_password):
|
||||
|
|
@ -1127,7 +1127,7 @@ class TenantService:
|
|||
def get_custom_config(tenant_id: str) -> dict:
|
||||
tenant = db.get_or_404(Tenant, tenant_id)
|
||||
|
||||
return cast(dict, tenant.custom_config_dict)
|
||||
return tenant.custom_config_dict
|
||||
|
||||
@staticmethod
|
||||
def is_owner(account: Account, tenant: Tenant) -> bool:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import uuid
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
from flask_login import current_user
|
||||
|
|
@ -40,7 +40,7 @@ class AppAnnotationService:
|
|||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
annotation = message.annotation
|
||||
annotation: Optional[MessageAnnotation] = message.annotation
|
||||
# save the message annotation
|
||||
if annotation:
|
||||
annotation.content = args["answer"]
|
||||
|
|
@ -70,7 +70,7 @@ class AppAnnotationService:
|
|||
app_id,
|
||||
annotation_setting.collection_binding_id,
|
||||
)
|
||||
return cast(MessageAnnotation, annotation)
|
||||
return annotation
|
||||
|
||||
@classmethod
|
||||
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:
|
||||
|
|
|
|||
|
|
@ -1149,7 +1149,7 @@ class DocumentService:
|
|||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
|
@ -1612,7 +1612,7 @@ class DocumentService:
|
|||
search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
reranking_enable=False,
|
||||
reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
|
||||
top_k=2,
|
||||
top_k=4,
|
||||
score_threshold_enabled=False,
|
||||
)
|
||||
# save dataset
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ default_retrieval_model = {
|
|||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
"top_k": 2,
|
||||
"top_k": 4,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
|
|
@ -66,7 +66,7 @@ class HitTestingService:
|
|||
retrieval_method=retrieval_model.get("search_method", "semantic_search"),
|
||||
dataset_id=dataset.id,
|
||||
query=query,
|
||||
top_k=retrieval_model.get("top_k", 2),
|
||||
top_k=retrieval_model.get("top_k", 4),
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
|
|
|
|||
|
|
@ -112,7 +112,9 @@ class MessageService:
|
|||
base_query = base_query.where(Message.conversation_id == conversation.id)
|
||||
|
||||
# Check include_ids before filtering: an empty list would compile to an always-false WHERE condition (see the sketch after this hunk)
|
||||
if include_ids is not None and len(include_ids) > 0:
|
||||
if include_ids is not None:
|
||||
if len(include_ids) == 0:
|
||||
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
|
||||
base_query = base_query.where(Message.id.in_(include_ids))
|
||||
|
||||
if last_id:
|
||||
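# The early return above avoids handing SQLAlchemy an empty IN clause, which compiles to an
# always-false predicate. A standalone sketch of that behaviour (table/column are invented):
from sqlalchemy import column, select
_stmt = select(column("id")).where(column("id").in_([]))
print(_stmt)  # renders an always-false IN predicate; exact SQL varies by SQLAlchemy version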
|
|
|
|||
|
|
@ -27,73 +27,73 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
|
|||
documents = []
|
||||
start_at = time.perf_counter()
|
||||
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if dataset is None:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# check document limit
|
||||
features = FeatureService.get_features(dataset.tenant_id)
|
||||
try:
|
||||
if features.billing.enabled:
|
||||
vector_space = features.vector_space
|
||||
count = len(document_ids)
|
||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
|
||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||
if count > batch_upload_limit:
|
||||
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
|
||||
if 0 < vector_space.limit <= vector_space.size:
|
||||
raise ValueError(
|
||||
"Your total number of documents plus the number of uploads have over the limit of "
|
||||
"your subscription."
|
||||
)
|
||||
except Exception as e:
|
||||
for document_id in document_ids:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
except Exception as e:
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = str(e)
|
||||
document.stopped_at = naive_utc_now()
|
||||
# clean old data
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
return
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
for document_id in document_ids:
|
||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
||||
|
||||
document = (
|
||||
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
||||
)
|
||||
|
||||
if document:
|
||||
# clean old data
|
||||
index_type = document.doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
|
||||
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
|
||||
if segments:
|
||||
index_node_ids = [segment.index_node_id for segment in segments]
|
||||
|
||||
# delete from vector index
|
||||
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
|
||||
|
||||
for segment in segments:
|
||||
db.session.delete(segment)
|
||||
db.session.commit()
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
documents.append(document)
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
indexing_runner = IndexingRunner()
|
||||
indexing_runner.run(documents)
|
||||
end_at = time.perf_counter()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import time
|
||||
import uuid
|
||||
from os import getenv
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -11,7 +10,6 @@ from core.workflow.enums import WorkflowNodeExecutionStatus
|
|||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.entities import CodeNodeData
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
|
@ -242,8 +240,6 @@ def test_execute_code_output_validator_depth():
|
|||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
node._node_data = cast(CodeNodeData, node._node_data)
|
||||
|
||||
# validate
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
|
@ -338,8 +334,6 @@ def test_execute_code_output_object_list():
|
|||
]
|
||||
}
|
||||
|
||||
node._node_data = cast(CodeNodeData, node._node_data)
|
||||
|
||||
# validate
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,390 @@
|
|||
import time
|
||||
import uuid
|
||||
from os import getenv
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.node_factory import DifyNodeFactory
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
CODE_MAX_STRING_LENGTH = int(getenv("CODE_MAX_STRING_LENGTH", "10000"))
|
||||
|
||||
|
||||
def init_code_node(code_config: dict):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-code-target",
|
||||
"source": "start",
|
||||
"target": "code",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config],
|
||||
}
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="aaa", files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["code", "args1"], 1)
|
||||
variable_pool.add(["code", "args2"], 2)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# Create node factory
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
node = CodeNode(
|
||||
id=str(uuid.uuid4()),
|
||||
config=code_config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in code_config:
|
||||
node.init_node_data(code_config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code(setup_code_executor_mock):
|
||||
code = """
|
||||
def main(args1: int, args2: int) -> dict:
|
||||
return {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
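# (aside) the manual 4-space trim above could also use the standard library, which would be
# equivalent for this snippet since every non-empty line is indented by four spaces:
#     import textwrap
#     code = textwrap.dedent(code)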
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["result"] == 3
|
||||
assert result.error == ""
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code_output_validator(setup_code_executor_mock):
|
||||
code = """
|
||||
def main(args1: int, args2: int) -> dict:
|
||||
return {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args1"], 1)
|
||||
node.graph_runtime_state.variable_pool.add(["1", "args2"], 2)
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.FAILED
|
||||
assert result.error == "Output variable `result` must be a string"
|
||||
|
||||
|
||||
def test_execute_code_output_validator_depth():
|
||||
code = """
|
||||
def main(args1: int, args2: int) -> dict:
|
||||
return {
|
||||
"result": {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"outputs": {
|
||||
"string_validator": {
|
||||
"type": "string",
|
||||
},
|
||||
"number_validator": {
|
||||
"type": "number",
|
||||
},
|
||||
"number_array_validator": {
|
||||
"type": "array[number]",
|
||||
},
|
||||
"string_array_validator": {
|
||||
"type": "array[string]",
|
||||
},
|
||||
"object_validator": {
|
||||
"type": "object",
|
||||
"children": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
},
|
||||
"depth": {
|
||||
"type": "object",
|
||||
"children": {
|
||||
"depth": {
|
||||
"type": "object",
|
||||
"children": {
|
||||
"depth": {
|
||||
"type": "number",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": 1,
|
||||
"string_validator": "1",
|
||||
"number_array_validator": [1, 2, 3, 3.333],
|
||||
"string_array_validator": ["1", "2", "3"],
|
||||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": "1",
|
||||
"string_validator": 1,
|
||||
"number_array_validator": ["1", "2", "3", "3.333"],
|
||||
"string_array_validator": [1, 2, 3],
|
||||
"object_validator": {"result": "1", "depth": {"depth": {"depth": "1"}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": 1,
|
||||
"string_validator": (CODE_MAX_STRING_LENGTH + 1) * "1",
|
||||
"number_array_validator": [1, 2, 3, 3.333],
|
||||
"string_array_validator": ["1", "2", "3"],
|
||||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"number_validator": 1,
|
||||
"string_validator": "1",
|
||||
"number_array_validator": [1, 2, 3, 3.333] * 2000,
|
||||
"string_array_validator": ["1", "2", "3"],
|
||||
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
||||
def test_execute_code_output_object_list():
|
||||
code = """
|
||||
def main(args1: int, args2: int) -> dict:
|
||||
return {
|
||||
"result": {
|
||||
"result": args1 + args2,
|
||||
}
|
||||
}
|
||||
"""
|
||||
# trim first 4 spaces at the beginning of each line
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"outputs": {
|
||||
"object_list": {
|
||||
"type": "array[object]",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "args2"]},
|
||||
],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"object_list": [
|
||||
{
|
||||
"result": 1,
|
||||
},
|
||||
{
|
||||
"result": 2,
|
||||
},
|
||||
{
|
||||
"result": [1, 2, 3],
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
# validate
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
# construct result
|
||||
result = {
|
||||
"object_list": [
|
||||
{
|
||||
"result": 1,
|
||||
},
|
||||
{
|
||||
"result": 2,
|
||||
},
|
||||
{
|
||||
"result": [1, 2, 3],
|
||||
},
|
||||
1,
|
||||
]
|
||||
}
|
||||
|
||||
# validate
|
||||
with pytest.raises(ValueError):
|
||||
node._transform_result(result, node._node_data.outputs)
|
||||
|
||||
|
||||
def test_execute_code_scientific_notation():
|
||||
code = """
|
||||
def main() -> dict:
|
||||
return {
|
||||
"result": -8.0E-5
|
||||
}
|
||||
"""
|
||||
code = "\n".join([line[4:] for line in code.split("\n")])
|
||||
|
||||
code_config = {
|
||||
"id": "code",
|
||||
"data": {
|
||||
"outputs": {
|
||||
"result": {
|
||||
"type": "number",
|
||||
},
|
||||
},
|
||||
"title": "123",
|
||||
"variables": [],
|
||||
"answer": "123",
|
||||
"code_language": "python3",
|
||||
"code": code,
|
||||
},
|
||||
}
|
||||
|
||||
node = init_code_node(code_config)
|
||||
# execute node
|
||||
result = node._run()
|
||||
assert isinstance(result, NodeRunResult)
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
|
@ -0,0 +1,788 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
|
||||
from core.tools.entities.api_entities import ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
|
||||
class TestToolTransformService:
|
||||
"""Integration tests for ToolTransformService using testcontainers."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.tools.tools_transform_service.dify_config") as mock_dify_config,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_dify_config.CONSOLE_API_URL = "https://console.example.com"
|
||||
|
||||
yield {
|
||||
"dify_config": mock_dify_config,
|
||||
}
|
||||
|
||||
def _create_test_tool_provider(
|
||||
self, db_session_with_containers, mock_external_service_dependencies, provider_type="api"
|
||||
):
|
||||
"""
|
||||
Helper method to create a test tool provider for testing.
|
||||
|
||||
Args:
|
||||
db_session_with_containers: Database session from testcontainers infrastructure
|
||||
mock_external_service_dependencies: Mock dependencies
|
||||
provider_type: Type of provider to create
|
||||
|
||||
Returns:
|
||||
Tool provider instance
|
||||
"""
|
||||
fake = Faker()
|
||||
|
||||
if provider_type == "api":
|
||||
provider = ApiToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
credentials={"auth_type": "api_key_header", "api_key": "test_key"},
|
||||
provider_type="api",
|
||||
)
|
||||
elif provider_type == "builtin":
|
||||
provider = BuiltinToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon="🔧",
|
||||
icon_dark="🔧",
|
||||
tenant_id="test_tenant_id",
|
||||
provider="test_provider",
|
||||
credential_type="api_key",
|
||||
credentials={"api_key": "test_key"},
|
||||
)
|
||||
elif provider_type == "workflow":
|
||||
provider = WorkflowToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
icon_dark='{"background": "#252525", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
workflow_id="test_workflow_id",
|
||||
)
|
||||
elif provider_type == "mcp":
|
||||
provider = MCPToolProvider(
|
||||
name=fake.company(),
|
||||
description=fake.text(max_nb_chars=100),
|
||||
provider_icon='{"background": "#FF6B6B", "content": "🔧"}',
|
||||
tenant_id="test_tenant_id",
|
||||
user_id="test_user_id",
|
||||
server_url="https://mcp.example.com",
|
||||
server_identifier="test_server",
|
||||
tools='[{"name": "test_tool", "description": "Test tool"}]',
|
||||
authed=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider type: {provider_type}")
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
return provider
|
||||
|
||||
    def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful plugin icon URL generation.

        This test verifies:
        - Proper URL construction for plugin icons
        - Correct tenant_id and filename handling
        - URL format compliance
        """
        # Arrange: Setup test data
        fake = Faker()
        tenant_id = fake.uuid4()
        filename = "test_icon.png"

        # Act: Execute the method under test
        result = ToolTransformService.get_plugin_icon_url(tenant_id, filename)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, str)
        assert "console/api/workspaces/current/plugin/icon" in result
        assert tenant_id in result
        assert filename in result
        assert result.startswith("https://console.example.com")

        # Verify URL structure
        expected_url = f"https://console.example.com/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
        assert result == expected_url

    def test_get_plugin_icon_url_with_empty_console_url(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test plugin icon URL generation when CONSOLE_API_URL is empty.

        This test verifies:
        - Fallback to relative URL when CONSOLE_API_URL is None
        - Proper URL construction with relative path
        """
        # Arrange: Setup mock with empty console URL
        mock_external_service_dependencies["dify_config"].CONSOLE_API_URL = None
        fake = Faker()
        tenant_id = fake.uuid4()
        filename = "test_icon.png"

        # Act: Execute the method under test
        result = ToolTransformService.get_plugin_icon_url(tenant_id, filename)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, str)
        assert result.startswith("/console/api/workspaces/current/plugin/icon")
        assert tenant_id in result
        assert filename in result

        # Verify URL structure
        expected_url = f"/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={filename}"
        assert result == expected_url

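    # Rough sketch (an assumption, not taken from this diff): the two tests above pin
    # ToolTransformService.get_plugin_icon_url down to roughly this behavior, with the
    # configured CONSOLE_API_URL falling back to an empty prefix when unset.
    #
    #     url_prefix = (dify_config.CONSOLE_API_URL or "") + "/console/api/workspaces/current/plugin/icon"
    #     return f"{url_prefix}?tenant_id={tenant_id}&filename={filename}"
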
    def test_get_tool_provider_icon_url_builtin_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful tool provider icon URL generation for builtin providers.

        This test verifies:
        - Proper URL construction for builtin tool providers
        - Correct provider type handling
        - URL format compliance
        """
        # Arrange: Setup test data
        fake = Faker()
        provider_type = ToolProviderType.BUILT_IN.value
        provider_name = fake.company()
        icon = "🔧"

        # Act: Execute the method under test
        result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, str)
        assert "console/api/workspaces/current/tool-provider/builtin" in result
        # Note: provider_name may contain spaces that get URL encoded
        assert provider_name.replace(" ", "%20") in result or provider_name in result
        assert result.endswith("/icon")
        assert result.startswith("https://console.example.com")

        # Verify URL structure (accounting for URL encoding)
        # The actual result will have URL-encoded spaces (%20), so we need to compare accordingly
        expected_url = (
            f"https://console.example.com/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon"
        )
        # Convert expected URL to match the actual URL encoding
        expected_encoded = expected_url.replace(" ", "%20")
        assert result == expected_encoded

    def test_get_tool_provider_icon_url_api_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful tool provider icon URL generation for API providers.

        This test verifies:
        - Proper icon handling for API tool providers
        - JSON string parsing for icon data
        - Fallback icon when parsing fails
        """
        # Arrange: Setup test data
        fake = Faker()
        provider_type = ToolProviderType.API.value
        provider_name = fake.company()
        icon = '{"background": "#FF6B6B", "content": "🔧"}'

        # Act: Execute the method under test
        result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, dict)
        assert result["background"] == "#FF6B6B"
        assert result["content"] == "🔧"

    def test_get_tool_provider_icon_url_api_invalid_json(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test tool provider icon URL generation for API providers with invalid JSON.

        This test verifies:
        - Proper fallback when JSON parsing fails
        - Default icon structure when exception occurs
        """
        # Arrange: Setup test data with invalid JSON
        fake = Faker()
        provider_type = ToolProviderType.API.value
        provider_name = fake.company()
        icon = '{"invalid": json}'

        # Act: Execute the method under test
        result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, dict)
        assert result["background"] == "#252525"
        # Note: emoji characters may be represented as Unicode escape sequences
        assert result["content"] == "😁" or result["content"] == "\ud83d\ude01"

    def test_get_tool_provider_icon_url_workflow_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful tool provider icon URL generation for workflow providers.

        This test verifies:
        - Proper icon handling for workflow tool providers
        - Direct icon return for workflow type
        """
        # Arrange: Setup test data
        fake = Faker()
        provider_type = ToolProviderType.WORKFLOW.value
        provider_name = fake.company()
        icon = {"background": "#FF6B6B", "content": "🔧"}

        # Act: Execute the method under test
        result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, dict)
        assert result["background"] == "#FF6B6B"
        assert result["content"] == "🔧"

    def test_get_tool_provider_icon_url_mcp_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful tool provider icon URL generation for MCP providers.

        This test verifies:
        - Direct icon return for MCP type
        - No URL transformation for MCP providers
        """
        # Arrange: Setup test data
        fake = Faker()
        provider_type = ToolProviderType.MCP.value
        provider_name = fake.company()
        icon = {"background": "#FF6B6B", "content": "🔧"}

        # Act: Execute the method under test
        result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert isinstance(result, dict)
        assert result["background"] == "#FF6B6B"
        assert result["content"] == "🔧"

    def test_get_tool_provider_icon_url_unknown_type(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test tool provider icon URL generation for unknown provider types.

        This test verifies:
        - Empty string return for unknown provider types
        - Proper handling of unsupported types
        """
        # Arrange: Setup test data with unknown type
        fake = Faker()
        provider_type = "unknown_type"
        provider_name = fake.company()
        icon = "🔧"

        # Act: Execute the method under test
        result = ToolTransformService.get_tool_provider_icon_url(provider_type, provider_name, icon)

        # Assert: Verify the expected outcomes
        assert result == ""

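    # Rough sketch (an assumption, not taken from this diff): taken together, the
    # get_tool_provider_icon_url tests above pin down a dispatch on provider type
    # roughly like the one below; names such as console_url are placeholders and the
    # service's exact enum handling may differ.
    #
    #     if provider_type == ToolProviderType.BUILT_IN.value:
    #         return f"{console_url}/console/api/workspaces/current/tool-provider/builtin/{provider_name}/icon"
    #     elif provider_type == ToolProviderType.API.value:
    #         try:
    #             return json.loads(icon)  # dict like {"background": ..., "content": ...}
    #         except Exception:
    #             return {"background": "#252525", "content": "😁"}  # fallback icon
    #     elif provider_type in (ToolProviderType.WORKFLOW.value, ToolProviderType.MCP.value):
    #         return icon  # returned as-is
    #     return ""  # unknown types
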
    def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful provider repacking with dictionary input.

        This test verifies:
        - Proper icon URL generation for dictionary providers
        - Correct provider type handling
        - Icon transformation for different provider types
        """
        # Arrange: Setup test data
        fake = Faker()
        tenant_id = fake.uuid4()
        provider = {"type": ToolProviderType.BUILT_IN.value, "name": fake.company(), "icon": "🔧"}

        # Act: Execute the method under test
        ToolTransformService.repack_provider(tenant_id, provider)

        # Assert: Verify the expected outcomes
        assert "icon" in provider
        assert isinstance(provider["icon"], str)
        assert "console/api/workspaces/current/tool-provider/builtin" in provider["icon"]
        # Note: provider name may contain spaces that get URL encoded
        assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"]

    def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful provider repacking with ToolProviderApiEntity input.

        This test verifies:
        - Proper icon URL generation for entity providers
        - Plugin icon handling when plugin_id is present
        - Regular icon handling when plugin_id is not present
        """
        # Arrange: Setup test data
        fake = Faker()
        tenant_id = fake.uuid4()

        # Create provider entity with plugin_id
        provider = ToolProviderApiEntity(
            id=fake.uuid4(),
            author=fake.name(),
            name=fake.company(),
            description=I18nObject(en_US=fake.text(max_nb_chars=100)),
            icon="test_icon.png",
            icon_dark="test_icon_dark.png",
            label=I18nObject(en_US=fake.company()),
            type=ToolProviderType.API,
            masked_credentials={},
            is_team_authorization=True,
            plugin_id="test_plugin_id",
            tools=[],
            labels=[],
        )

        # Act: Execute the method under test
        ToolTransformService.repack_provider(tenant_id, provider)

        # Assert: Verify the expected outcomes
        assert provider.icon is not None
        assert isinstance(provider.icon, str)
        assert "console/api/workspaces/current/plugin/icon" in provider.icon
        assert tenant_id in provider.icon
        assert "test_icon.png" in provider.icon

        # Verify dark icon handling
        assert provider.icon_dark is not None
        assert isinstance(provider.icon_dark, str)
        assert "console/api/workspaces/current/plugin/icon" in provider.icon_dark
        assert tenant_id in provider.icon_dark
        assert "test_icon_dark.png" in provider.icon_dark

    def test_repack_provider_entity_no_plugin_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful provider repacking with ToolProviderApiEntity input without plugin_id.

        This test verifies:
        - Proper icon URL generation for non-plugin providers
        - Regular tool provider icon handling
        - Dark icon handling when present
        """
        # Arrange: Setup test data
        fake = Faker()
        tenant_id = fake.uuid4()

        # Create provider entity without plugin_id
        provider = ToolProviderApiEntity(
            id=fake.uuid4(),
            author=fake.name(),
            name=fake.company(),
            description=I18nObject(en_US=fake.text(max_nb_chars=100)),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            icon_dark='{"background": "#252525", "content": "🔧"}',
            label=I18nObject(en_US=fake.company()),
            type=ToolProviderType.API,
            masked_credentials={},
            is_team_authorization=True,
            plugin_id=None,
            tools=[],
            labels=[],
        )

        # Act: Execute the method under test
        ToolTransformService.repack_provider(tenant_id, provider)

        # Assert: Verify the expected outcomes
        assert provider.icon is not None
        assert isinstance(provider.icon, dict)
        assert provider.icon["background"] == "#FF6B6B"
        assert provider.icon["content"] == "🔧"

        # Verify dark icon handling
        assert provider.icon_dark is not None
        assert isinstance(provider.icon_dark, dict)
        assert provider.icon_dark["background"] == "#252525"
        assert provider.icon_dark["content"] == "🔧"

    def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test provider repacking with ToolProviderApiEntity input without dark icon.

        This test verifies:
        - Proper handling when icon_dark is None or empty
        - No errors when dark icon is not present
        """
        # Arrange: Setup test data
        fake = Faker()
        tenant_id = fake.uuid4()

        # Create provider entity without dark icon
        provider = ToolProviderApiEntity(
            id=fake.uuid4(),
            author=fake.name(),
            name=fake.company(),
            description=I18nObject(en_US=fake.text(max_nb_chars=100)),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            icon_dark=None,
            label=I18nObject(en_US=fake.company()),
            type=ToolProviderType.API,
            masked_credentials={},
            is_team_authorization=True,
            plugin_id=None,
            tools=[],
            labels=[],
        )

        # Act: Execute the method under test
        ToolTransformService.repack_provider(tenant_id, provider)

        # Assert: Verify the expected outcomes
        assert provider.icon is not None
        assert isinstance(provider.icon, dict)
        assert provider.icon["background"] == "#FF6B6B"
        assert provider.icon["content"] == "🔧"

        # Verify dark icon remains None
        assert provider.icon_dark is None

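    # Rough sketch (an assumption, not taken from this diff): the four repack_provider
    # tests above imply behavior along these lines for dict and entity inputs; method
    # names on the right are the same service helpers exercised earlier in this class.
    #
    #     if isinstance(provider, dict):
    #         provider["icon"] = get_tool_provider_icon_url(provider["type"], provider["name"], provider["icon"])
    #     elif provider.plugin_id:
    #         provider.icon = get_plugin_icon_url(tenant_id, provider.icon)
    #         if provider.icon_dark:
    #             provider.icon_dark = get_plugin_icon_url(tenant_id, provider.icon_dark)
    #     else:
    #         provider.icon = get_tool_provider_icon_url(provider.type.value, provider.name, provider.icon)
    #         if provider.icon_dark:
    #             provider.icon_dark = get_tool_provider_icon_url(provider.type.value, provider.name, provider.icon_dark)
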
    def test_builtin_provider_to_user_provider_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of builtin provider to user provider.

        This test verifies:
        - Proper entity creation with all required fields
        - Credentials schema handling
        - Team authorization setup
        - Plugin ID handling
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create mock provider controller
        mock_controller = Mock()
        mock_controller.entity.identity.name = fake.company()
        mock_controller.entity.identity.author = fake.name()
        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
        mock_controller.entity.identity.icon = "🔧"
        mock_controller.entity.identity.icon_dark = "🔧"
        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
        mock_controller.plugin_id = None
        mock_controller.plugin_unique_identifier = None
        mock_controller.tool_labels = ["label1", "label2"]
        mock_controller.need_credentials = True

        # Mock credentials schema
        mock_credential = Mock()
        mock_credential.to_basic_provider_config.return_value.name = "api_key"
        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]

        # Create mock database provider
        mock_db_provider = Mock()
        mock_db_provider.credential_type = "api-key"
        mock_db_provider.tenant_id = fake.uuid4()
        mock_db_provider.credentials = {"api_key": "encrypted_key"}

        # Mock encryption
        with patch("services.tools.tools_transform_service.create_provider_encrypter") as mock_encrypter:
            mock_encrypter_instance = Mock()
            mock_encrypter_instance.decrypt.return_value = {"api_key": "decrypted_key"}
            mock_encrypter_instance.mask_tool_credentials.return_value = {"api_key": ""}
            mock_encrypter.return_value = (mock_encrypter_instance, None)

            # Act: Execute the method under test
            result = ToolTransformService.builtin_provider_to_user_provider(
                mock_controller, mock_db_provider, decrypt_credentials=True
            )

            # Assert: Verify the expected outcomes
            assert result is not None
            assert result.id == mock_controller.entity.identity.name
            assert result.author == mock_controller.entity.identity.author
            assert result.name == mock_controller.entity.identity.name
            assert result.description == mock_controller.entity.identity.description
            assert result.icon == mock_controller.entity.identity.icon
            assert result.icon_dark == mock_controller.entity.identity.icon_dark
            assert result.label == mock_controller.entity.identity.label
            assert result.type == ToolProviderType.BUILT_IN
            assert result.is_team_authorization is True
            assert result.plugin_id is None
            assert result.tools == []
            assert result.labels == ["label1", "label2"]
            assert result.masked_credentials == {"api_key": ""}
            assert result.original_credentials == {"api_key": "decrypted_key"}

    def test_builtin_provider_to_user_provider_plugin_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of builtin provider to user provider with plugin.

        This test verifies:
        - Plugin ID and unique identifier handling
        - Proper entity creation for plugin providers
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create mock provider controller with plugin
        mock_controller = Mock()
        mock_controller.entity.identity.name = fake.company()
        mock_controller.entity.identity.author = fake.name()
        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
        mock_controller.entity.identity.icon = "🔧"
        mock_controller.entity.identity.icon_dark = "🔧"
        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
        mock_controller.plugin_id = "test_plugin_id"
        mock_controller.plugin_unique_identifier = "test_unique_id"
        mock_controller.tool_labels = ["label1"]
        mock_controller.need_credentials = False

        # Mock credentials schema
        mock_credential = Mock()
        mock_credential.to_basic_provider_config.return_value.name = "api_key"
        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]

        # Act: Execute the method under test
        result = ToolTransformService.builtin_provider_to_user_provider(
            mock_controller, None, decrypt_credentials=False
        )

        # Assert: Verify the expected outcomes
        assert result is not None
        # Note: The method checks isinstance(provider_controller, PluginToolProviderController)
        # Since we're using a Mock, this check will fail, so plugin_id will remain None
        # In a real test with actual PluginToolProviderController, this would work
        assert result.is_team_authorization is True
        assert result.allow_delete is False

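    # Possible tightening (an assumption, not part of this diff): a Mock created with a
    # spec passes isinstance() checks, so the plugin branch noted above could be
    # exercised directly instead of being skipped, e.g.:
    #
    #     mock_controller = Mock(spec=PluginToolProviderController)
    #     mock_controller.plugin_id = "test_plugin_id"
    #     mock_controller.plugin_unique_identifier = "test_unique_id"
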
    def test_builtin_provider_to_user_provider_no_credentials(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of builtin provider to user provider without credentials.

        This test verifies:
        - Proper handling when no credentials are needed
        - Team authorization setup for no-credentials providers
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create mock provider controller
        mock_controller = Mock()
        mock_controller.entity.identity.name = fake.company()
        mock_controller.entity.identity.author = fake.name()
        mock_controller.entity.identity.description = I18nObject(en_US=fake.text(max_nb_chars=100))
        mock_controller.entity.identity.icon = "🔧"
        mock_controller.entity.identity.icon_dark = "🔧"
        mock_controller.entity.identity.label = I18nObject(en_US=fake.company())
        mock_controller.plugin_id = None
        mock_controller.plugin_unique_identifier = None
        mock_controller.tool_labels = []
        mock_controller.need_credentials = False

        # Mock credentials schema
        mock_credential = Mock()
        mock_credential.to_basic_provider_config.return_value.name = "api_key"
        mock_controller.get_credentials_schema_by_type.return_value = [mock_credential]

        # Act: Execute the method under test
        result = ToolTransformService.builtin_provider_to_user_provider(
            mock_controller, None, decrypt_credentials=False
        )

        # Assert: Verify the expected outcomes
        assert result is not None
        assert result.is_team_authorization is True
        assert result.allow_delete is False
        assert result.masked_credentials == {}

    def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies):
        """
        Test successful conversion of API provider to controller.

        This test verifies:
        - Proper controller creation from database provider
        - Auth type handling for different credential types
        - Backward compatibility for auth types
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create API tool provider with api_key_header auth
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}',
            schema="{}",
            schema_type_str="openapi",
            tools_str="[]",
        )

        from extensions.ext_database import db

        db.session.add(provider)
        db.session.commit()

        # Act: Execute the method under test
        result = ToolTransformService.api_provider_to_controller(provider)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert hasattr(result, "from_db")
        # Additional assertions would depend on the actual controller implementation

    def test_api_provider_to_controller_api_key_query(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of API provider to controller with api_key_query auth type.

        This test verifies:
        - Proper auth type handling for query parameter authentication
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create API tool provider with api_key_query auth
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            credentials_str='{"auth_type": "api_key_query", "api_key": "test_key"}',
            schema="{}",
            schema_type_str="openapi",
            tools_str="[]",
        )

        from extensions.ext_database import db

        db.session.add(provider)
        db.session.commit()

        # Act: Execute the method under test
        result = ToolTransformService.api_provider_to_controller(provider)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert hasattr(result, "from_db")

    def test_api_provider_to_controller_backward_compatibility(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test conversion of API provider to controller with backward compatibility auth types.

        This test verifies:
        - Proper handling of legacy auth type values
        - Backward compatibility for api_key and api_key_header
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create API tool provider with legacy auth type
        provider = ApiToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            credentials_str='{"auth_type": "api_key", "api_key": "test_key"}',
            schema="{}",
            schema_type_str="openapi",
            tools_str="[]",
        )

        from extensions.ext_database import db

        db.session.add(provider)
        db.session.commit()

        # Act: Execute the method under test
        result = ToolTransformService.api_provider_to_controller(provider)

        # Assert: Verify the expected outcomes
        assert result is not None
        assert hasattr(result, "from_db")

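    # Rough sketch (an assumption, not taken from this diff): the three
    # api_provider_to_controller tests above exercise header, query, and legacy
    # auth_type values; presumably the service normalizes them along these lines
    # before building the controller. Enum and member names here are illustrative only.
    #
    #     auth_type = credentials.get("auth_type")
    #     if auth_type in ("api_key_header", "api_key"):  # "api_key" kept for backward compatibility
    #         auth = ApiProviderAuthType.API_KEY_HEADER
    #     elif auth_type == "api_key_query":
    #         auth = ApiProviderAuthType.API_KEY_QUERY
    #     else:
    #         auth = ApiProviderAuthType.NONE
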
    def test_workflow_provider_to_controller_success(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test successful conversion of workflow provider to controller.

        This test verifies:
        - Proper controller creation from workflow provider
        - Workflow-specific controller handling
        """
        # Arrange: Setup test data
        fake = Faker()

        # Create workflow tool provider
        provider = WorkflowToolProvider(
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            icon='{"background": "#FF6B6B", "content": "🔧"}',
            tenant_id=fake.uuid4(),
            user_id=fake.uuid4(),
            app_id=fake.uuid4(),
            label="Test Workflow",
            version="1.0.0",
            parameter_configuration="[]",
        )

        from extensions.ext_database import db

        db.session.add(provider)
        db.session.commit()

        # Mock the WorkflowToolProviderController.from_db method to avoid app dependency
        with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db:
            mock_controller = Mock()
            mock_from_db.return_value = mock_controller

            # Act: Execute the method under test
            result = ToolTransformService.workflow_provider_to_controller(provider)

            # Assert: Verify the expected outcomes
            assert result is not None
            assert result == mock_controller
            mock_from_db.assert_called_once_with(provider)

@@ -621,7 +621,7 @@ export default translation
            && !trimmed.startsWith('//'))
          break
      }
      else {
        break
      }