mirror of
https://github.com/langgenius/dify.git
synced 2026-06-24 21:11:16 +08:00
chore: not use request.scoped session (#37421)
Co-authored-by: WH-2099 <wh2099@pm.me>
This commit is contained in:
parent
7d2f25df8e
commit
0cc27dd401
@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
@ -22,7 +22,7 @@ from core.app.entities.queue_entities import (
|
||||
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
|
||||
from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.db.session_factory import create_session, session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
@ -107,7 +107,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_id=self.application_generate_entity.workflow_run_id,
|
||||
)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(select(App).where(App.id == app_config.app_id))
|
||||
|
||||
if not app_record:
|
||||
@ -204,6 +204,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
trace_session_id=self.application_generate_entity.extras.get("trace_session_id"),
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before workflow execution so a checked-out DB connection
|
||||
# is not held for the lifetime of the graph run.
|
||||
db.session.close()
|
||||
|
||||
# RUN WORKFLOW
|
||||
@ -368,7 +370,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
|
||||
:return: List of conversation variables ready for use
|
||||
"""
|
||||
with sessionmaker(bind=db.engine).begin() as session:
|
||||
with create_session() as session, session.begin():
|
||||
existing_variables = self._load_existing_conversation_variables(session)
|
||||
|
||||
if not existing_variables:
|
||||
|
||||
@ -12,10 +12,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.db.session_factory import create_session
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from extensions.ext_database import db
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.model_providers.base.large_language_model import LargeLanguageModel
|
||||
@ -47,7 +47,10 @@ class AgentChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(AgentChatAppConfig, app_config)
|
||||
app_stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(app_stmt)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(app_stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -185,14 +188,18 @@ class AgentChatAppRunner(AppRunner):
|
||||
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
|
||||
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
conversation_stmt = select(Conversation).where(Conversation.id == conversation.id)
|
||||
conversation_result = db.session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
msg_stmt = select(Message).where(Message.id == message.id)
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
with create_session() as session:
|
||||
conversation_result = session.scalar(conversation_stmt)
|
||||
if conversation_result is None:
|
||||
raise ValueError("Conversation not found")
|
||||
|
||||
message_result = session.scalar(msg_stmt)
|
||||
if message_result is not None:
|
||||
session.expunge(message_result)
|
||||
session.expunge(conversation_result)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
|
||||
@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
)
|
||||
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import create_session
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
@ -46,7 +47,10 @@ class ChatAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(ChatAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -216,6 +220,8 @@ class ChatAppRunner(AppRunner):
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
|
||||
# is not held for the lifetime of the provider response.
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
|
||||
@ -10,6 +10,7 @@ from core.app.entities.app_invoke_entities import (
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import create_session
|
||||
from core.model_manager import ModelInstance
|
||||
from core.moderation.base import ModerationError
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
@ -39,7 +40,10 @@ class CompletionAppRunner(AppRunner):
|
||||
app_config = application_generate_entity.app_config
|
||||
app_config = cast(CompletionAppConfig, app_config)
|
||||
stmt = select(App).where(App.id == app_config.app_id)
|
||||
app_record = db.session.scalar(stmt)
|
||||
with create_session() as session:
|
||||
app_record = session.scalar(stmt)
|
||||
if app_record:
|
||||
session.expunge(app_record)
|
||||
if not app_record:
|
||||
raise ValueError("App not found")
|
||||
|
||||
@ -174,6 +178,8 @@ class CompletionAppRunner(AppRunner):
|
||||
model=application_generate_entity.model_conf.model,
|
||||
)
|
||||
|
||||
# Release the Flask scoped session before LLM streaming so a checked-out DB connection
|
||||
# is not held for the lifetime of the provider response.
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
|
||||
@ -3,6 +3,7 @@ import time
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
@ -14,12 +15,12 @@ from core.app.entities.app_invoke_entities import (
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
|
||||
from core.db.session_factory import create_session
|
||||
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
|
||||
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
|
||||
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
|
||||
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from graphon.enums import WorkflowType
|
||||
from graphon.graph import Graph
|
||||
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
|
||||
@ -83,22 +84,24 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
user_id = None
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
with create_session() as session:
|
||||
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = session.get(EndUser, self.application_generate_entity.user_id)
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
pipeline = session.get(Pipeline, app_config.app_id)
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
workflow = self.get_workflow(session=session, pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
db.session.close()
|
||||
session.expunge(pipeline)
|
||||
session.expunge(workflow)
|
||||
|
||||
# if only single iteration run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
@ -208,12 +211,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
def get_workflow(self, session: Session, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = db.session.scalar(
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
|
||||
.limit(1)
|
||||
@ -298,11 +301,11 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
if document_id and dataset_id:
|
||||
document = db.session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
db.session.add(document)
|
||||
db.session.commit()
|
||||
with create_session() as session, session.begin():
|
||||
document = session.scalar(
|
||||
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
|
||||
)
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = event.error or "Unknown error"
|
||||
session.add(document)
|
||||
|
||||
@ -3,7 +3,6 @@ from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
@ -13,10 +12,19 @@ from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.db.session_factory import create_session
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models import Account, TenantAccountJoin
|
||||
from models.model import (
|
||||
App,
|
||||
AppMode,
|
||||
AppModelConfig,
|
||||
AppModelConfigDict,
|
||||
EndUser,
|
||||
load_annotation_reply_config,
|
||||
)
|
||||
from models.workflow import Workflow
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@ -30,18 +38,18 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
"""Retrieve app parameters."""
|
||||
if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
workflow = app.workflow
|
||||
workflow = cls._get_workflow(app)
|
||||
if workflow is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict: dict[str, Any] = workflow.features_dict
|
||||
user_input_form = workflow.user_input_form(to_old_structure=True)
|
||||
else:
|
||||
app_model_config = app.app_model_config
|
||||
if app_model_config is None:
|
||||
app_model_config_dict = cls._get_app_model_config_dict(app)
|
||||
if app_model_config_dict is None:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
features_dict = cast(dict[str, Any], app_model_config.to_dict())
|
||||
features_dict = cast(dict[str, Any], app_model_config_dict)
|
||||
|
||||
user_input_form = features_dict.get("user_input_form", [])
|
||||
|
||||
@ -68,7 +76,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not user_id:
|
||||
user = EndUserService.get_or_create_end_user(app)
|
||||
else:
|
||||
user = cls._get_user(user_id)
|
||||
user = cls._get_user(user_id, app)
|
||||
|
||||
conversation_id = conversation_id or ""
|
||||
|
||||
@ -79,7 +87,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
|
||||
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
|
||||
case AppMode.WORKFLOW:
|
||||
return cls.invoke_workflow_app(app, user, stream, inputs, files)
|
||||
workflow = cls._get_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
return cls.invoke_workflow_app(app, workflow, user, stream, inputs, files)
|
||||
case AppMode.COMPLETION:
|
||||
return cls.invoke_completion_app(app, user, stream, inputs, files)
|
||||
case _:
|
||||
@ -101,7 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
match app.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow = app.workflow
|
||||
workflow = cls._get_workflow(app)
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
@ -158,6 +169,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
def invoke_workflow_app(
|
||||
cls,
|
||||
app: App,
|
||||
workflow: Workflow,
|
||||
user: EndUser | Account,
|
||||
stream: bool,
|
||||
inputs: Mapping,
|
||||
@ -166,10 +178,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"""
|
||||
invoke workflow app
|
||||
"""
|
||||
workflow = app.workflow
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
@ -207,16 +215,26 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_user(cls, user_id: str) -> EndUser | Account:
|
||||
def _get_user(cls, user_id: str, app: App) -> EndUser | Account:
|
||||
"""
|
||||
get the user by user id
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(EndUser).where(EndUser.id == user_id)
|
||||
with create_session() as session:
|
||||
stmt = select(EndUser).where(
|
||||
EndUser.id == user_id,
|
||||
EndUser.tenant_id == app.tenant_id,
|
||||
EndUser.app_id == app.id,
|
||||
)
|
||||
user = session.scalar(stmt)
|
||||
if not user:
|
||||
stmt = select(Account).where(Account.id == user_id)
|
||||
stmt = select(Account).where(
|
||||
Account.id == user_id,
|
||||
Account.id == TenantAccountJoin.account_id,
|
||||
TenantAccountJoin.tenant_id == app.tenant_id,
|
||||
)
|
||||
user = session.scalar(stmt)
|
||||
if user:
|
||||
session.expunge(user)
|
||||
|
||||
if not user:
|
||||
raise ValueError("user not found")
|
||||
@ -229,7 +247,10 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
with create_session() as session:
|
||||
app = session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
if app:
|
||||
session.expunge(app)
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
@ -237,3 +258,41 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
raise ValueError("app not found")
|
||||
|
||||
return app
|
||||
|
||||
@classmethod
|
||||
def _get_workflow(cls, app: App) -> Workflow | None:
|
||||
"""
|
||||
get workflow without relying on App.workflow's request-scoped session property
|
||||
"""
|
||||
if not app.workflow_id:
|
||||
return None
|
||||
|
||||
with create_session() as session:
|
||||
workflow = session.scalar(
|
||||
select(Workflow)
|
||||
.where(Workflow.id == app.workflow_id, Workflow.tenant_id == app.tenant_id, Workflow.app_id == app.id)
|
||||
.limit(1)
|
||||
)
|
||||
if workflow:
|
||||
session.expunge(workflow)
|
||||
return workflow
|
||||
|
||||
@classmethod
|
||||
def _get_app_model_config_dict(cls, app: App) -> AppModelConfigDict | None:
|
||||
"""
|
||||
get app model config features without relying on request-scoped session-backed model properties
|
||||
"""
|
||||
if not app.app_model_config_id:
|
||||
return None
|
||||
|
||||
with create_session() as session:
|
||||
app_model_config = session.scalar(
|
||||
select(AppModelConfig)
|
||||
.where(AppModelConfig.id == app.app_model_config_id, AppModelConfig.app_id == app.id)
|
||||
.limit(1)
|
||||
)
|
||||
if app_model_config is None:
|
||||
return None
|
||||
|
||||
annotation_reply = load_annotation_reply_config(session, app_model_config.app_id)
|
||||
return app_model_config.to_dict(annotation_reply=annotation_reply)
|
||||
|
||||
@ -774,26 +774,7 @@ class AppModelConfig(TypeBase):
|
||||
|
||||
@property
|
||||
def annotation_reply_dict(self) -> AnnotationReplyConfig:
|
||||
annotation_setting = db.session.scalar(
|
||||
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id)
|
||||
)
|
||||
if annotation_setting:
|
||||
collection_binding_detail = annotation_setting.collection_binding_detail
|
||||
if not collection_binding_detail:
|
||||
raise ValueError("Collection binding detail not found")
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
|
||||
else:
|
||||
return {"enabled": False}
|
||||
return load_annotation_reply_config(db.session(), self.app_id)
|
||||
|
||||
@property
|
||||
def more_like_this_dict(self) -> EnabledConfig:
|
||||
@ -864,7 +845,7 @@ class AppModelConfig(TypeBase):
|
||||
},
|
||||
)
|
||||
|
||||
def to_dict(self) -> AppModelConfigDict:
|
||||
def to_dict(self, *, annotation_reply: AnnotationReplyConfig | None = None) -> AppModelConfigDict:
|
||||
return {
|
||||
"opening_statement": self.opening_statement,
|
||||
"suggested_questions": self.suggested_questions_list,
|
||||
@ -872,7 +853,7 @@ class AppModelConfig(TypeBase):
|
||||
"speech_to_text": self.speech_to_text_dict,
|
||||
"text_to_speech": self.text_to_speech_dict,
|
||||
"retriever_resource": self.retriever_resource_dict,
|
||||
"annotation_reply": self.annotation_reply_dict,
|
||||
"annotation_reply": annotation_reply if annotation_reply is not None else self.annotation_reply_dict,
|
||||
"more_like_this": self.more_like_this_dict,
|
||||
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
|
||||
"external_data_tools": self.external_data_tools_list,
|
||||
@ -2038,6 +2019,30 @@ class AppAnnotationSetting(TypeBase):
|
||||
)
|
||||
|
||||
|
||||
def load_annotation_reply_config(session: Session, app_id: str) -> AnnotationReplyConfig:
|
||||
annotation_setting = session.scalar(select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id))
|
||||
if annotation_setting is None:
|
||||
return {"enabled": False}
|
||||
|
||||
from .dataset import DatasetCollectionBinding
|
||||
|
||||
collection_binding_detail = session.scalar(
|
||||
select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == annotation_setting.collection_binding_id)
|
||||
)
|
||||
if collection_binding_detail is None:
|
||||
raise ValueError("Collection binding detail not found")
|
||||
|
||||
return {
|
||||
"id": annotation_setting.id,
|
||||
"enabled": True,
|
||||
"score_threshold": annotation_setting.score_threshold,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": collection_binding_detail.provider_name,
|
||||
"embedding_model_name": collection_binding_detail.model_name,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OperationLog(TypeBase):
|
||||
__tablename__ = "operation_logs"
|
||||
__table_args__ = (
|
||||
|
||||
@ -25,6 +25,15 @@ MINIMAL_GRAPH = {
|
||||
}
|
||||
|
||||
|
||||
def _patch_create_session(mock_session: MagicMock):
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
mock_session.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.begin.return_value.__exit__.return_value = False
|
||||
return patch("core.app.apps.advanced_chat.app_runner.create_session", return_value=session_context)
|
||||
|
||||
|
||||
class TestAdvancedChatAppRunnerConversationVariables:
|
||||
"""Test that AdvancedChatAppRunner correctly handles conversation variables."""
|
||||
|
||||
@ -135,10 +144,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
_patch_create_session(mock_session),
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
@ -151,12 +158,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
@ -281,10 +282,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
_patch_create_session(mock_session),
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
@ -298,12 +297,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock ConversationVariable.from_variable to return mock objects
|
||||
mock_conv_vars = []
|
||||
for var in workflow_vars:
|
||||
@ -434,10 +427,8 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
|
||||
# Patch the necessary components
|
||||
with (
|
||||
patch("core.app.apps.advanced_chat.app_runner.sessionmaker") as mock_sessionmaker,
|
||||
patch("core.app.apps.advanced_chat.app_runner.Session") as mock_session_class,
|
||||
_patch_create_session(mock_session),
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
@ -450,12 +441,6 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.redis_client") as mock_redis_client,
|
||||
patch("core.app.apps.advanced_chat.app_runner.RedisChannel") as mock_redis_channel_class,
|
||||
):
|
||||
# Setup mocks
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
|
||||
mock_session_class.return_value.__enter__.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
|
||||
# Mock GraphRuntimeState to accept the variable pool
|
||||
mock_graph_runtime_state_class.return_value = MagicMock()
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
import core.app.apps.advanced_chat.app_runner as module
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueStopEvent
|
||||
@ -85,27 +86,24 @@ def build_runner():
|
||||
|
||||
def _patch_common_run_deps(runner: AdvancedChatAppRunner):
|
||||
"""Context manager that patches common heavy deps used by run()."""
|
||||
# create_session() returns a context manager whose body yields a session that
|
||||
# supports both scalar() (app record lookup) and begin()/scalars().all()
|
||||
# (conversation variable initialization).
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock()
|
||||
mock_session.scalars.return_value.all.return_value = []
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
mock_session.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_session.begin.return_value.__exit__.return_value = False
|
||||
|
||||
return patch.multiple(
|
||||
"core.app.apps.advanced_chat.app_runner",
|
||||
Session=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=lambda *a, **k: False,
|
||||
scalar=lambda *a, **k: MagicMock(),
|
||||
),
|
||||
),
|
||||
sessionmaker=MagicMock(
|
||||
return_value=MagicMock(
|
||||
begin=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: MagicMock(scalars=MagicMock(return_value=MagicMock(all=lambda: []))),
|
||||
__exit__=lambda *a, **k: False,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
create_session=MagicMock(return_value=session_context),
|
||||
select=MagicMock(),
|
||||
db=MagicMock(engine=MagicMock()),
|
||||
session_factory=MagicMock(get_session_maker=MagicMock(return_value=MagicMock())),
|
||||
RedisChannel=MagicMock(),
|
||||
redis_client=MagicMock(),
|
||||
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}),
|
||||
@ -192,3 +190,42 @@ def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_
|
||||
# Ensure no further steps executed
|
||||
mock_anno.assert_not_called()
|
||||
mock_init_graph.assert_not_called()
|
||||
|
||||
|
||||
def test_run_closes_scoped_session_before_workflow_run(build_runner):
|
||||
runner = build_runner
|
||||
events = []
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.scalar.return_value = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = mock_session
|
||||
session_context.__exit__.return_value = False
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
|
||||
def run_workflow():
|
||||
events.append("run")
|
||||
return iter([])
|
||||
|
||||
workflow_entry.run.side_effect = run_workflow
|
||||
|
||||
with (
|
||||
patch.object(module, "create_session", return_value=session_context),
|
||||
patch.object(module, "session_factory", MagicMock(get_session_maker=MagicMock(return_value=MagicMock()))),
|
||||
patch.object(module, "RedisChannel"),
|
||||
patch.object(module, "redis_client"),
|
||||
patch.object(module, "WorkflowEntry", return_value=workflow_entry),
|
||||
patch.object(module.db.session, "close", side_effect=lambda: events.append("close")),
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, runner.application_generate_entity.inputs, runner.application_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch.object(runner, "_initialize_conversation_variables", return_value=[]),
|
||||
patch.object(runner, "_init_graph", return_value=MagicMock()),
|
||||
):
|
||||
runner.run()
|
||||
|
||||
assert events == ["close", "run"]
|
||||
|
||||
@ -13,12 +13,24 @@ def runner():
|
||||
return AgentChatAppRunner()
|
||||
|
||||
|
||||
def patch_create_session(mocker: MockerFixture, *, return_value=None, side_effect=None):
|
||||
session = mocker.MagicMock()
|
||||
if side_effect is not None:
|
||||
session.scalar.side_effect = side_effect
|
||||
else:
|
||||
session.scalar.return_value = return_value
|
||||
session_context = mocker.MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.create_session", return_value=session_context)
|
||||
return session
|
||||
|
||||
|
||||
class TestAgentChatAppRunnerRun:
|
||||
def test_run_app_not_found(self, runner: AgentChatAppRunner, mocker: MockerFixture):
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock())
|
||||
generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None)
|
||||
patch_create_session(mocker, return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())
|
||||
@ -37,7 +49,7 @@ class TestAgentChatAppRunnerRun:
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad"))
|
||||
mocker.patch.object(runner, "direct_output")
|
||||
@ -62,7 +74,7 @@ class TestAgentChatAppRunnerRun:
|
||||
invoke_from=mocker.MagicMock(),
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
annotation = mocker.MagicMock(id="anno", content="answer")
|
||||
@ -91,7 +103,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -121,7 +133,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -163,7 +175,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -179,10 +191,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
|
||||
@ -219,7 +228,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -235,10 +244,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
@ -267,7 +273,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -283,10 +289,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
|
||||
@ -323,10 +326,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, None],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, None])
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -357,10 +357,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, mocker.MagicMock(id="conv"), None],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, mocker.MagicMock(id="conv"), None])
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -391,7 +388,7 @@ class TestAgentChatAppRunnerRun:
|
||||
user_id="user",
|
||||
)
|
||||
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record)
|
||||
patch_create_session(mocker, return_value=app_record)
|
||||
mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None))
|
||||
mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q"))
|
||||
mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None)
|
||||
@ -407,10 +404,7 @@ class TestAgentChatAppRunnerRun:
|
||||
|
||||
conversation = mocker.MagicMock(id="conv")
|
||||
message = mocker.MagicMock(id="msg")
|
||||
mocker.patch(
|
||||
"core.app.apps.agent_chat.app_runner.db.session.scalar",
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
patch_create_session(mocker, side_effect=[app_record, conversation, message])
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import ANY, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -29,6 +30,19 @@ class DummyQueueManager:
|
||||
self.published.append((event, pub_from))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patched_create_session(*, return_value=None, side_effect=None):
|
||||
session = MagicMock()
|
||||
if side_effect is not None:
|
||||
session.scalar.side_effect = side_effect
|
||||
else:
|
||||
session.scalar.return_value = return_value
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
with patch("core.app.apps.chat.app_runner.create_session", return_value=session_context):
|
||||
yield session
|
||||
|
||||
|
||||
class TestChatAppGenerator:
|
||||
def test_generate_requires_query(self):
|
||||
generator = ChatAppGenerator()
|
||||
@ -167,7 +181,7 @@ class TestChatAppRunner:
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None):
|
||||
with patched_create_session(return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
@ -195,10 +209,7 @@ class TestChatAppRunner:
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(ChatAppRunner, "direct_output") as mock_direct,
|
||||
@ -233,10 +244,7 @@ class TestChatAppRunner:
|
||||
annotation = SimpleNamespace(id="ann-1", content="answer")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation),
|
||||
@ -272,13 +280,73 @@ class TestChatAppRunner:
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"core.app.apps.chat.app_runner.db.session.scalar",
|
||||
return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"),
|
||||
),
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True),
|
||||
):
|
||||
runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
def test_run_closes_scoped_session_before_stream_consumption(self):
|
||||
runner = ChatAppRunner()
|
||||
app_config = SimpleNamespace(
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
prompt_template=None,
|
||||
external_data_variables=[],
|
||||
dataset=None,
|
||||
additional_features=None,
|
||||
)
|
||||
app_generate_entity = DummyGenerateEntity(
|
||||
app_config=app_config,
|
||||
model_conf=SimpleNamespace(provider_model_bundle=None, model="model-1", parameters={}),
|
||||
inputs={},
|
||||
query="hi",
|
||||
files=[],
|
||||
file_upload_config=None,
|
||||
conversation_id=None,
|
||||
stream=True,
|
||||
user_id="user-1",
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
)
|
||||
|
||||
events = []
|
||||
queue_manager = DummyQueueManager()
|
||||
model_instance = MagicMock()
|
||||
|
||||
def invoke_stream():
|
||||
events.append("first-chunk")
|
||||
yield "chunk"
|
||||
|
||||
def invoke_llm(**kwargs):
|
||||
events.append("invoke")
|
||||
return invoke_stream()
|
||||
|
||||
with (
|
||||
patched_create_session(return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1")),
|
||||
patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])),
|
||||
patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")),
|
||||
patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None),
|
||||
patch.object(ChatAppRunner, "check_hosting_moderation", return_value=False),
|
||||
patch.object(ChatAppRunner, "recalc_llm_max_tokens"),
|
||||
patch.object(
|
||||
ChatAppRunner,
|
||||
"_handle_invoke_result",
|
||||
side_effect=lambda invoke_result, **kwargs: list(invoke_result),
|
||||
) as mock_handle,
|
||||
patch("core.app.apps.chat.app_runner.ModelInstance", return_value=model_instance),
|
||||
patch("core.app.apps.chat.app_runner.db.session.close", side_effect=lambda: events.append("close")),
|
||||
):
|
||||
model_instance.invoke_llm.side_effect = invoke_llm
|
||||
runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1"))
|
||||
|
||||
assert events == ["close", "invoke", "first-chunk"]
|
||||
mock_handle.assert_called_once_with(
|
||||
invoke_result=ANY,
|
||||
queue_manager=queue_manager,
|
||||
stream=True,
|
||||
message_id="m1",
|
||||
user_id="user-1",
|
||||
tenant_id="tenant-1",
|
||||
)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
@ -47,25 +48,28 @@ def _build_generate_entity(app_config, file_upload_config=None):
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def patched_create_session(*, return_value=None):
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = return_value
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
with patch.object(module, "create_session", return_value=session_context):
|
||||
yield session
|
||||
|
||||
|
||||
class TestCompletionAppRunner:
|
||||
def test_run_app_not_found(self, runner, mocker: MockerFixture):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = None
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
with patched_create_session(return_value=None):
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock())
|
||||
|
||||
def test_run_moderation_error_outputs_direct(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
@ -74,7 +78,8 @@ class TestCompletionAppRunner:
|
||||
runner.direct_output = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner.direct_output.assert_called_once()
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
@ -82,10 +87,6 @@ class TestCompletionAppRunner:
|
||||
def test_run_hosting_moderation_stops(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
|
||||
@ -94,18 +95,14 @@ class TestCompletionAppRunner:
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
runner._handle_invoke_result = MagicMock()
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
runner._handle_invoke_result.assert_not_called()
|
||||
|
||||
def test_run_dataset_and_external_tools_flow(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
session.close = MagicMock()
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
retrieve_config = MagicMock(query_variable="qvar")
|
||||
dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config)
|
||||
additional_features = MagicMock(show_retrieve_source=True)
|
||||
@ -135,19 +132,56 @@ class TestCompletionAppRunner:
|
||||
model_instance.invoke_llm.return_value = "invoke_result"
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant"))
|
||||
|
||||
dataset_retrieval.retrieve.assert_called_once()
|
||||
assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input"
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_closes_scoped_session_before_stream_consumption(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config)
|
||||
queue_manager = MagicMock()
|
||||
|
||||
events = []
|
||||
runner.organize_prompt_messages = MagicMock(return_value=([], None))
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=False)
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner._handle_invoke_result = MagicMock(side_effect=lambda invoke_result, **kwargs: list(invoke_result))
|
||||
|
||||
model_instance = MagicMock()
|
||||
|
||||
def invoke_stream():
|
||||
events.append("first-chunk")
|
||||
yield "chunk"
|
||||
|
||||
def invoke_llm(**kwargs):
|
||||
events.append("invoke")
|
||||
return invoke_stream()
|
||||
|
||||
model_instance.invoke_llm.side_effect = invoke_llm
|
||||
mocker.patch.object(module, "ModelInstance", return_value=model_instance)
|
||||
mocker.patch.object(module.db.session, "close", side_effect=lambda: events.append("close"))
|
||||
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, queue_manager, MagicMock(id="msg"))
|
||||
|
||||
assert events == ["close", "invoke", "first-chunk"]
|
||||
runner._handle_invoke_result.assert_called_once_with(
|
||||
invoke_result=ANY,
|
||||
queue_manager=queue_manager,
|
||||
stream=True,
|
||||
message_id="msg",
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
)
|
||||
|
||||
def test_run_uses_low_image_detail_default(self, runner, mocker: MockerFixture):
|
||||
app_record = MagicMock(id="app1", tenant_id="tenant")
|
||||
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = app_record
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
app_config = _build_app_config()
|
||||
app_generate_entity = _build_generate_entity(app_config, file_upload_config=None)
|
||||
|
||||
@ -155,7 +189,8 @@ class TestCompletionAppRunner:
|
||||
runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query"))
|
||||
runner.check_hosting_moderation = MagicMock(return_value=True)
|
||||
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
with patched_create_session(return_value=app_record):
|
||||
runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg"))
|
||||
|
||||
assert (
|
||||
runner.organize_prompt_messages.call_args.kwargs["image_detail_config"]
|
||||
|
||||
@ -53,6 +53,32 @@ def _build_app_generate_entity() -> SimpleNamespace:
|
||||
)
|
||||
|
||||
|
||||
def _patch_create_session(mocker: MockerFixture, session: MagicMock, *, events: list[str] | None = None):
|
||||
"""Patch create_session() to yield ``session`` inside its ``with`` body and ``begin()`` block.
|
||||
|
||||
The runner now obtains short-lived sessions via ``create_session()`` instead of the
|
||||
Flask scoped ``db.session``, so tests patch the module-level ``create_session`` and
|
||||
hand back a context manager that yields the mock session.
|
||||
"""
|
||||
session_context = MagicMock()
|
||||
|
||||
def enter_session():
|
||||
if events is not None:
|
||||
events.append("session_enter")
|
||||
return session
|
||||
|
||||
def exit_session(*args):
|
||||
if events is not None:
|
||||
events.append("session_exit")
|
||||
return False
|
||||
|
||||
session_context.__enter__.side_effect = enter_session
|
||||
session_context.__exit__.side_effect = exit_session
|
||||
session.begin.return_value.__enter__.return_value = session
|
||||
session.begin.return_value.__exit__.return_value = False
|
||||
return mocker.patch.object(module, "create_session", return_value=session_context)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
app_generate_entity = _build_app_generate_entity()
|
||||
@ -77,13 +103,14 @@ def test_get_app_id(runner):
|
||||
assert runner._get_app_id() == "pipe"
|
||||
|
||||
|
||||
def test_get_workflow_returns_workflow(mocker, runner):
|
||||
def test_get_workflow_returns_workflow(runner):
|
||||
pipeline = MagicMock(tenant_id="tenant", id="pipe")
|
||||
workflow = MagicMock(id="wf")
|
||||
|
||||
mocker.patch.object(module.db, "session", MagicMock(scalar=MagicMock(return_value=workflow)))
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = workflow
|
||||
|
||||
result = runner.get_workflow(pipeline=pipeline, workflow_id="wf")
|
||||
result = runner.get_workflow(session=session, pipeline=pipeline, workflow_id="wf")
|
||||
|
||||
assert result == workflow
|
||||
|
||||
@ -116,7 +143,7 @@ def test_update_document_status_on_failure(mocker, runner):
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = document
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
event = GraphRunFailedEvent(error="boom")
|
||||
|
||||
@ -124,7 +151,10 @@ def test_update_document_status_on_failure(mocker, runner):
|
||||
|
||||
assert document.indexing_status == "error"
|
||||
assert document.error == "boom"
|
||||
session.commit.assert_called_once()
|
||||
session.add.assert_called_once_with(document)
|
||||
session.begin.assert_called_once()
|
||||
session.begin.return_value.__enter__.assert_called_once()
|
||||
session.begin.return_value.__exit__.assert_called_once()
|
||||
|
||||
|
||||
def test_run_pipeline_not_found(mocker: MockerFixture):
|
||||
@ -135,7 +165,7 @@ def test_run_pipeline_not_found(mocker: MockerFixture):
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [None, None]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
@ -158,7 +188,7 @@ def test_run_workflow_not_initialized(mocker: MockerFixture):
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [None, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
@ -184,7 +214,7 @@ def test_run_single_iteration_path(mocker: MockerFixture):
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [end_user, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session)
|
||||
|
||||
runner = PipelineRunner(
|
||||
application_generate_entity=app_generate_entity,
|
||||
@ -229,10 +259,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
|
||||
|
||||
pipeline = MagicMock(id="pipe")
|
||||
end_user = MagicMock(session_id="sess")
|
||||
events = []
|
||||
|
||||
session = MagicMock()
|
||||
session.get.side_effect = [end_user, pipeline]
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
_patch_create_session(mocker, session, events=events)
|
||||
|
||||
workflow = MagicMock(
|
||||
id="wf",
|
||||
@ -276,10 +307,11 @@ def test_run_normal_path_builds_graph(mocker: MockerFixture):
|
||||
|
||||
workflow_entry = MagicMock()
|
||||
workflow_entry.graph_engine = MagicMock()
|
||||
workflow_entry.run.return_value = []
|
||||
workflow_entry.run.side_effect = lambda: events.append("workflow_run") or []
|
||||
mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry)
|
||||
mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock())
|
||||
|
||||
runner.run()
|
||||
|
||||
assert events == ["session_enter", "session_exit", "workflow_run"]
|
||||
runner._init_rag_pipeline_graph.assert_called_once()
|
||||
|
||||
@ -5,6 +5,7 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||
@ -16,6 +17,16 @@ class _Chunk(BaseModel):
|
||||
value: int
|
||||
|
||||
|
||||
def _build_app_model_config(result: dict | None = None):
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.app_id = "app-1"
|
||||
app_model_config.to_dict.return_value = result or {
|
||||
"user_input_form": [{"name": "bar"}],
|
||||
"annotation_reply": {"enabled": False},
|
||||
}
|
||||
return app_model_config
|
||||
|
||||
|
||||
class TestBaseBackwardsInvocation:
|
||||
def test_convert_to_event_stream_with_generator_and_error(self):
|
||||
def _stream():
|
||||
@ -42,12 +53,25 @@ class TestBaseBackwardsInvocation:
|
||||
|
||||
|
||||
class TestPluginAppBackwardsInvocation:
|
||||
def patch_create_session(self, mocker: MockerFixture, *, return_value=None, side_effect=None):
|
||||
session = MagicMock()
|
||||
if side_effect is not None:
|
||||
session.scalar.side_effect = side_effect
|
||||
else:
|
||||
session.scalar.return_value = return_value
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.create_session", return_value=session_ctx)
|
||||
return session
|
||||
|
||||
def test_fetch_app_info_workflow_path(self, mocker: MockerFixture):
|
||||
workflow = MagicMock()
|
||||
workflow.features_dict = {"feature": "v"}
|
||||
workflow.user_input_form.return_value = [{"name": "foo"}]
|
||||
app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow)
|
||||
app = MagicMock(mode=AppMode.WORKFLOW)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
|
||||
mapper = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
|
||||
return_value={"mapped": True},
|
||||
@ -59,10 +83,10 @@ class TestPluginAppBackwardsInvocation:
|
||||
mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}])
|
||||
|
||||
def test_fetch_app_info_model_config_path(self, mocker: MockerFixture):
|
||||
model_config = MagicMock()
|
||||
model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"}
|
||||
app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config)
|
||||
model_config_dict = {"user_input_form": [{"name": "bar"}], "k": "v"}
|
||||
app = MagicMock(mode=AppMode.COMPLETION)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app_model_config_dict", return_value=model_config_dict)
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.get_parameters_from_feature_dict",
|
||||
return_value={"mapped": True},
|
||||
@ -85,8 +109,10 @@ class TestPluginAppBackwardsInvocation:
|
||||
def test_invoke_app_routes_by_mode(self, mocker: MockerFixture, mode, route_method):
|
||||
app = MagicMock(mode=mode)
|
||||
user = MagicMock()
|
||||
workflow = MagicMock()
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
|
||||
route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True})
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_app(
|
||||
@ -106,7 +132,9 @@ class TestPluginAppBackwardsInvocation:
|
||||
def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker: MockerFixture):
|
||||
app = MagicMock(mode=AppMode.WORKFLOW)
|
||||
end_user = MagicMock()
|
||||
workflow = MagicMock()
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
|
||||
get_or_create = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user",
|
||||
return_value=end_user,
|
||||
@ -126,7 +154,8 @@ class TestPluginAppBackwardsInvocation:
|
||||
|
||||
assert result == {"ok": True}
|
||||
get_or_create.assert_called_once_with(app)
|
||||
assert route.call_args.args[1] is end_user
|
||||
assert route.call_args.args[1] is workflow
|
||||
assert route.call_args.args[2] is end_user
|
||||
|
||||
def test_invoke_app_missing_query_for_chat_raises(self, mocker: MockerFixture):
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT))
|
||||
@ -190,7 +219,7 @@ class TestPluginAppBackwardsInvocation:
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.ADVANCED_CHAT
|
||||
app.workflow = workflow
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=workflow)
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
@ -217,8 +246,9 @@ class TestPluginAppBackwardsInvocation:
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
def test_invoke_chat_app_advanced_chat_without_workflow_raises(self):
|
||||
app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None)
|
||||
def test_invoke_chat_app_advanced_chat_without_workflow_raises(self, mocker: MockerFixture):
|
||||
app = MagicMock(mode=AppMode.ADVANCED_CHAT)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=None)
|
||||
with pytest.raises(ValueError, match="unexpected app type"):
|
||||
PluginAppBackwardsInvocation.invoke_chat_app(
|
||||
app=app,
|
||||
@ -249,7 +279,6 @@ class TestPluginAppBackwardsInvocation:
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.WORKFLOW
|
||||
app.workflow = workflow
|
||||
|
||||
mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.db",
|
||||
@ -262,6 +291,7 @@ class TestPluginAppBackwardsInvocation:
|
||||
|
||||
result = PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
workflow=workflow,
|
||||
user=MagicMock(),
|
||||
stream=False,
|
||||
inputs={"k": "v"},
|
||||
@ -274,12 +304,18 @@ class TestPluginAppBackwardsInvocation:
|
||||
assert isinstance(pause_state_config, PauseStateLayerConfig)
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
def test_invoke_workflow_app_without_workflow_raises(self):
|
||||
app = MagicMock(mode=AppMode.WORKFLOW, workflow=None)
|
||||
def test_invoke_app_workflow_without_workflow_raises(self, mocker: MockerFixture):
|
||||
app = MagicMock(mode=AppMode.WORKFLOW)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app)
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock())
|
||||
mocker.patch.object(PluginAppBackwardsInvocation, "_get_workflow", return_value=None)
|
||||
with pytest.raises(ValueError, match="unexpected app type"):
|
||||
PluginAppBackwardsInvocation.invoke_workflow_app(
|
||||
app=app,
|
||||
user=MagicMock(),
|
||||
PluginAppBackwardsInvocation.invoke_app(
|
||||
app_id="app",
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
conversation_id=None,
|
||||
query=None,
|
||||
stream=False,
|
||||
inputs={},
|
||||
files=[],
|
||||
@ -297,58 +333,97 @@ class TestPluginAppBackwardsInvocation:
|
||||
assert spy.call_count == 1
|
||||
|
||||
def test_get_user_returns_end_user(self, mocker: MockerFixture):
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [MagicMock(id="end-user")]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
|
||||
session = self.patch_create_session(mocker, side_effect=[MagicMock(id="end-user")])
|
||||
app = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
user = PluginAppBackwardsInvocation._get_user("uid", app)
|
||||
|
||||
user = PluginAppBackwardsInvocation._get_user("uid")
|
||||
assert user.id == "end-user"
|
||||
stmt = session.scalar.call_args_list[0].args[0]
|
||||
compiled = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "end_users.id" in compiled
|
||||
assert "end_users.tenant_id" in compiled
|
||||
assert "end_users.app_id" in compiled
|
||||
assert stmt.compile().params == {"id_1": "uid", "tenant_id_1": "tenant-1", "app_id_1": "app-1"}
|
||||
|
||||
def test_get_user_falls_back_to_account_user(self, mocker: MockerFixture):
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [None, MagicMock(id="account-user")]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
|
||||
session = self.patch_create_session(mocker, side_effect=[None, MagicMock(id="account-user")])
|
||||
app = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
user = PluginAppBackwardsInvocation._get_user("uid", app)
|
||||
|
||||
user = PluginAppBackwardsInvocation._get_user("uid")
|
||||
assert user.id == "account-user"
|
||||
stmt = session.scalar.call_args_list[1].args[0]
|
||||
compiled = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "accounts.id" in compiled
|
||||
assert "tenant_account_joins.account_id" in compiled
|
||||
assert "tenant_account_joins.tenant_id" in compiled
|
||||
assert stmt.compile().params == {"id_1": "uid", "tenant_id_1": "tenant-1"}
|
||||
|
||||
def test_get_user_raises_when_user_not_found(self, mocker: MockerFixture):
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [None, None]
|
||||
session_ctx = MagicMock()
|
||||
session_ctx.__enter__.return_value = session
|
||||
session_ctx.__exit__.return_value = None
|
||||
mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx)
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock()))
|
||||
self.patch_create_session(mocker, side_effect=[None, None])
|
||||
app = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
|
||||
with pytest.raises(ValueError, match="user not found"):
|
||||
PluginAppBackwardsInvocation._get_user("uid")
|
||||
PluginAppBackwardsInvocation._get_user("uid", app)
|
||||
|
||||
def test_get_app_returns_app(self, mocker: MockerFixture):
|
||||
app_obj = MagicMock(id="app")
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
self.patch_create_session(mocker, return_value=app_obj)
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
|
||||
|
||||
def test_get_app_raises_when_missing(self, mocker: MockerFixture):
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
self.patch_create_session(mocker, return_value=None)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_app_raises_when_query_fails(self, mocker: MockerFixture):
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
self.patch_create_session(mocker, side_effect=RuntimeError("db down"))
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_workflow_stays_inside_app_boundary(self, mocker: MockerFixture):
|
||||
workflow = MagicMock(id="workflow")
|
||||
session = self.patch_create_session(mocker, return_value=workflow)
|
||||
app = SimpleNamespace(id="app-1", tenant_id="tenant-1", workflow_id="workflow-1")
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_workflow(app) is workflow
|
||||
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "workflows.id" in compiled
|
||||
assert "workflows.tenant_id" in compiled
|
||||
assert "workflows.app_id" in compiled
|
||||
assert stmt.compile().params == {
|
||||
"id_1": "workflow-1",
|
||||
"tenant_id_1": "tenant-1",
|
||||
"app_id_1": "app-1",
|
||||
"param_1": 1,
|
||||
}
|
||||
|
||||
def test_get_app_model_config_dict_uses_explicit_session_for_annotation_reply(self, mocker: MockerFixture):
|
||||
annotation_reply = {"enabled": False}
|
||||
app_model_config = _build_app_model_config()
|
||||
session = self.patch_create_session(mocker, return_value=app_model_config)
|
||||
load_annotation_reply_config = mocker.patch(
|
||||
"core.plugin.backwards_invocation.app.load_annotation_reply_config",
|
||||
return_value=annotation_reply,
|
||||
)
|
||||
app = SimpleNamespace(id="app-1", app_model_config_id="config-1")
|
||||
|
||||
result = PluginAppBackwardsInvocation._get_app_model_config_dict(app)
|
||||
|
||||
assert result is not None
|
||||
assert result["user_input_form"] == [{"name": "bar"}]
|
||||
assert result["annotation_reply"] == annotation_reply
|
||||
load_annotation_reply_config.assert_called_once_with(session, "app-1")
|
||||
app_model_config.to_dict.assert_called_once_with(annotation_reply=annotation_reply)
|
||||
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "app_model_configs.id" in compiled
|
||||
assert "app_model_configs.app_id" in compiled
|
||||
assert stmt.compile().params == {"id_1": "config-1", "app_id_1": "app-1", "param_1": 1}
|
||||
|
||||
@ -12,10 +12,11 @@ import json
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from models.enums import ConversationFromSource
|
||||
from models.model import (
|
||||
@ -29,6 +30,7 @@ from models.model import (
|
||||
Message,
|
||||
MessageAnnotation,
|
||||
Site,
|
||||
load_annotation_reply_config,
|
||||
)
|
||||
|
||||
|
||||
@ -342,6 +344,70 @@ class TestAppModelConfig:
|
||||
# Assert
|
||||
assert result == questions
|
||||
|
||||
def test_to_dict_uses_injected_annotation_reply(self):
|
||||
config = AppModelConfig(app_id=str(uuid4()))
|
||||
annotation_reply = {"enabled": False}
|
||||
|
||||
with patch.object(
|
||||
AppModelConfig,
|
||||
"annotation_reply_dict",
|
||||
new_callable=PropertyMock,
|
||||
side_effect=AssertionError("annotation_reply_dict should not be accessed"),
|
||||
):
|
||||
result = config.to_dict(annotation_reply=annotation_reply)
|
||||
|
||||
assert result["annotation_reply"] == annotation_reply
|
||||
|
||||
|
||||
class TestAnnotationReplyConfigLoader:
|
||||
def test_load_annotation_reply_config_returns_disabled_when_setting_missing(self):
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
result = load_annotation_reply_config(session, "app-1")
|
||||
|
||||
assert result == {"enabled": False}
|
||||
session.scalar.assert_called_once()
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "app_annotation_settings.app_id" in compiled
|
||||
assert stmt.compile().params == {"app_id_1": "app-1"}
|
||||
|
||||
def test_load_annotation_reply_config_returns_embedding_model(self):
|
||||
session = MagicMock()
|
||||
annotation_setting = SimpleNamespace(
|
||||
id="annotation-1",
|
||||
score_threshold=0.7,
|
||||
collection_binding_id="binding-1",
|
||||
)
|
||||
collection_binding = SimpleNamespace(provider_name="provider", model_name="embedding")
|
||||
session.scalar.side_effect = [annotation_setting, collection_binding]
|
||||
|
||||
result = load_annotation_reply_config(session, "app-1")
|
||||
|
||||
assert result == {
|
||||
"id": "annotation-1",
|
||||
"enabled": True,
|
||||
"score_threshold": 0.7,
|
||||
"embedding_model": {
|
||||
"embedding_provider_name": "provider",
|
||||
"embedding_model_name": "embedding",
|
||||
},
|
||||
}
|
||||
assert session.scalar.call_count == 2
|
||||
stmt = session.scalar.call_args_list[1].args[0]
|
||||
compiled = str(stmt.compile(dialect=postgresql.dialect()))
|
||||
assert "dataset_collection_bindings.id" in compiled
|
||||
assert stmt.compile().params == {"id_1": "binding-1"}
|
||||
|
||||
def test_load_annotation_reply_config_raises_when_binding_missing(self):
|
||||
session = MagicMock()
|
||||
annotation_setting = SimpleNamespace(collection_binding_id="binding-1")
|
||||
session.scalar.side_effect = [annotation_setting, None]
|
||||
|
||||
with pytest.raises(ValueError, match="Collection binding detail not found"):
|
||||
load_annotation_reply_config(session, "app-1")
|
||||
|
||||
|
||||
class TestConversationModel:
|
||||
"""Test suite for Conversation model integrity."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user