diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 43b204b78c..956fc60191 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -27,9 +27,7 @@ DEFAULT_FRAMEWORK_NAME = "dify" def get_user_id_from_message_data(message_data) -> str: user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id return user_id diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index 724127c31c..a1ea182f66 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -410,9 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id) if end_user_data is not None: metadata["end_user_id"] = end_user_data.session_id diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4a634e2e57..3bf01eb81c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -241,9 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id metadata["user_id"] = user_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 9f7d73b4ca..d960038f15 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -259,9 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index 8ec69e3542..8bf2e5dc13 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -9,6 +9,7 @@ from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace from mlflow.tracing.provider import detach_span_from_context, set_span_in_context +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig @@ -320,7 +321,7 @@ class MLflowDataTrace(BaseTraceInstance): def _get_message_user_id(self, metadata: dict) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( - end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user_data := db.session.get(EndUser, end_user_id) ): return end_user_data.session_id @@ -447,25 +448,11 @@ class MLflowDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( - WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.tenant_id, - WorkflowNodeExecutionModel.app_id, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.inputs, - WorkflowNodeExecutionModel.outputs, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.process_data, - WorkflowNodeExecutionModel.execution_metadata, - ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + workflow_nodes = db.session.scalars( + select(WorkflowNodeExecutionModel) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .order_by(WorkflowNodeExecutionModel.created_at) - .all() - ) + ).all() return workflow_nodes def _get_node_span_type(self, node_type: str) -> str: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index a3ead548bb..b98cc3ce59 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -288,9 +288,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 0a2a0642f1..9c36d57c6f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -420,10 +420,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: @@ -463,7 +463,7 @@ class OpsTraceManager: if isinstance(app_id, str) and app_id.startswith("tenant-"): return None - app: App | None = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app is None: return None @@ -537,7 +537,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App | None = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.get(App, app_id) if not app_config: raise ValueError("App not found") app_config.tracing = json.dumps( @@ -555,7 +555,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: App | None = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.get(App, app_id) if not app: raise ValueError("App not found") if not app.tracing: @@ -883,7 +883,7 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -972,8 +972,8 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -1015,8 +1015,8 @@ class TraceTask: # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -1171,7 +1171,7 @@ class TraceTask: metadata["node_execution_id"] = node_execution_id file_url = "" - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index a55505822a..f79544f1c7 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -245,9 +245,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id attributes["end_user_id"] = end_user_id diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index fa885e9320..e4d8f2d5ea 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): end_user_data = MagicMock(spec=EndUser) end_user_data.session_id = "session_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = end_user_data - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = end_user_data from core.ops.aliyun_trace.utils import db @@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): message_data.from_account_id = "account_id" message_data.from_end_user_id = "end_user_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = None from core.ops.aliyun_trace.utils import db diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index fdf66d4d40..8ebf441921 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index e89359c25b..34c64c54a1 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -319,9 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index 7ff6f7dcfd..afc5726ede 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -330,7 +330,7 @@ class TestTraceDispatcher: class TestWorkflowTrace: def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -343,7 +343,7 @@ class TestWorkflowTrace: span.end.assert_called_once() def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -374,7 +374,7 @@ class TestWorkflowTrace: ), outputs='{"text": "hello world"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + mock_db.session.scalars.return_value.all.return_value = [llm_node] workflow_span = MagicMock() node_span = MagicMock() @@ -397,7 +397,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + mock_db.session.scalars.return_value.all.return_value = [qc_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -411,7 +411,7 @@ class TestWorkflowTrace: node_type=BuiltinNodeTypes.HTTP_REQUEST, process_data='{"url": "https://api.com"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + mock_db.session.scalars.return_value.all.return_value = [http_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -434,7 +434,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + mock_db.session.scalars.return_value.all.return_value = [kr_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -448,7 +448,7 @@ class TestWorkflowTrace: def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): failed_node = _make_node(status="failed") - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + mock_db.session.scalars.return_value.all.return_value = [failed_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -459,7 +459,7 @@ class TestWorkflowTrace: node_span.add_event.assert_called_once() def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] workflow_span = MagicMock() mock_tracing["start"].return_value = workflow_span mock_tracing["set"].return_value = "token" @@ -473,7 +473,7 @@ class TestWorkflowTrace: def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): node = _make_node(inputs=None, outputs=None) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + mock_db.session.scalars.return_value.all.return_value = [node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -486,7 +486,7 @@ class TestWorkflowTrace: assert end_call.kwargs["outputs"] == {} def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -501,7 +501,7 @@ class TestWorkflowTrace: def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): """When query is empty string, it's falsy so no query key added.""" - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -680,12 +680,12 @@ class TestGetMessageUserId: def test_returns_end_user_session_id(self, trace_instance, mock_db): end_user = MagicMock() end_user.session_id = "session-1" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) assert result == "session-1" def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) assert result == "acc-1" @@ -834,7 +834,7 @@ class TestGenerateNameTrace: class TestGetWorkflowNodes: def test_queries_db(self, trace_instance, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + mock_db.session.scalars.return_value.all.return_value = ["n1", "n2"] result = trace_instance._get_workflow_nodes("run-1") assert result == ["n1", "n2"] diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index 6625cb719f..c02ac413f2 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -373,9 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index f81806c941..e47df0121e 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -157,17 +157,19 @@ def make_workflow_run(): ) -def configure_db_query(session, *, message_file=None, workflow_app_log=None): - def _side_effect(model): - query = MagicMock() - query.filter_by.return_value.first.return_value = None - if message_file and model.__name__ == "MessageFile": - query.filter_by.return_value.first.return_value = message_file - if workflow_app_log and model.__name__ == "WorkflowAppLog": - query.filter_by.return_value.first.return_value = workflow_app_log - return query +def configure_db_scalar(session, *, message_file=None, workflow_app_log=None): + """Configure session.scalar to return appropriate values for MessageFile/WorkflowAppLog lookups.""" + original_scalar = session.scalar - session.query.side_effect = _side_effect + def _side_effect(stmt): + stmt_str = str(stmt) + if "message_file" in stmt_str.lower(): + return message_file + if "workflow_app_log" in stmt_str.lower(): + return workflow_app_log + return original_scalar(stmt) + + session.scalar.side_effect = _side_effect class DummySessionContext: @@ -263,7 +265,7 @@ def workflow_repo_fixture(monkeypatch): def trace_task_message(monkeypatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) - configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + configure_db_scalar(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) return message_data @@ -307,56 +309,53 @@ def test_obfuscated_decrypt_token(encryption_mocks): def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data app = SimpleNamespace(id="app-id", tenant_id="tenant") - mock_db.scalar.return_value = app + mock_db.scalar.side_effect = [trace_config_data, app] decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") assert decrypted["other_value"] == "info" def test_get_decrypted_tracing_config_missing_trace_config(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.scalar.return_value = None assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = None + mock_db.scalar.side_effect = [trace_config_data, None] with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): trace_config_data = SimpleNamespace(tracing_config=None) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + mock_db.scalar.side_effect = [trace_config_data, SimpleNamespace(tenant_id="tenant")] with pytest.raises(ValueError, match="Tracing config cannot be None"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_ops_trace_instance_handles_none_app(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_success(monkeypatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr( "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), @@ -390,7 +389,7 @@ def test_get_app_config_through_message_id_app_model_config(mock_db): def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="Invalid tracing provider"): OpsTraceManager.update_app_tracing_config("app", True, "bad") with pytest.raises(ValueError, match="App not found"): @@ -399,26 +398,26 @@ def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): def test_update_app_tracing_config_success(mock_db): app = SimpleNamespace(id="app-id", tracing="{}") - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") assert app.tracing is not None mock_db.commit.assert_called_once() def test_get_app_tracing_config_errors_when_missing(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_app_tracing_config("app") def test_get_app_tracing_config_returns_defaults(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + mock_db.get.return_value = SimpleNamespace(tracing=None) assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} def test_get_app_tracing_config_returns_payload(mock_db): payload = {"enabled": True, "tracing_provider": "dummy"} - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + mock_db.get.return_value = SimpleNamespace(tracing=json.dumps(payload)) assert OpsTraceManager.get_app_tracing_config("app-id") == payload @@ -501,7 +500,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): def test_trace_task_tool_trace(monkeypatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) - configure_db_query(mock_db, message_file=FakeMessageFile()) + configure_db_scalar(mock_db, message_file=FakeMessageFile()) task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") timer = {"start": 1, "end": 5} result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8987b6682c..531c7de05f 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -802,8 +802,8 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.query", - lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + "core.ops.weave_trace.weave_trace.db.session.get", + lambda model, pk: None, ) trace_instance.start_call = MagicMock() @@ -823,7 +823,7 @@ class TestMessageTrace: trace_instance.file_base_url = "http://files.test" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -845,7 +845,7 @@ class TestMessageTrace: end_user.session_id = "session-xyz" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -865,7 +865,7 @@ class TestMessageTrace: def test_message_trace_no_end_user(self, trace_instance, monkeypatch): """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -883,7 +883,7 @@ class TestMessageTrace: def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -898,7 +898,7 @@ class TestMessageTrace: def test_message_trace_file_list_none(self, trace_instance, monkeypatch): """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock()