refactor: migrate session.query to select API in core misc modules (#34608)

This commit is contained in:
Renzo 2026-04-06 23:08:34 -05:00 committed by GitHub
parent 2f9667de76
commit b55bef4438
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 26 additions and 56 deletions

View File

@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = (
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
agent_thought: MessageAgentThought | None = session.scalar(
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
)
if agent_thought:

View File

@ -345,8 +345,8 @@ class DatasourceManager:
@classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session:
upload_file = (
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
upload_file = session.scalar(
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
)
if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

View File

@ -467,7 +467,7 @@ class LLMGenerator:
):
session = db.session()
app: App | None = session.query(App).where(App.id == flow_id).first()
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
if not app:
raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
if not service_account:
raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = (
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
current_tenant = session.scalar(
select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
)
if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
def __exit__(self, exc_type, exc, tb):
return False
def query(self, *args, **kwargs):
return self
def where(self, *args, **kwargs):
return self
def first(self):
def scalar(self, *args, **kwargs):
return agent_thought
monkeypatch.setattr(
@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
def __exit__(self, exc_type, exc, tb):
return False
def query(self, *args, **kwargs):
return self
def where(self, *args, **kwargs):
return self
def first(self):
def scalar(self, *args, **kwargs):
return None
monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session)

View File

@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
source_url="http://x",
)
class _Q:
def __init__(self, row):
self._row = row
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._row
class _S:
def __init__(self, row):
self._row = row
@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker):
def __exit__(self, *exc):
return False
def query(self, *_):
return _Q(self._row)
def scalar(self, *_args, **_kwargs):
return self._row
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
def test_get_upload_file_by_id_raises_when_missing(mocker):
class _Q:
def where(self, *_args, **_kwargs):
return self
def first(self):
return None
class _S:
def __enter__(self):
return self
@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker):
def __exit__(self, *exc):
return False
def query(self, *_):
return _Q()
def scalar(self, *_args, **_kwargs):
return None
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())

View File

@ -346,13 +346,13 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_app_not_found(self):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = None
mock_session.return_value.scalar.return_value = None
with pytest.raises(ValueError, match="App not found."):
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
def test_instruction_modify_workflow_no_workflow(self):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = None
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
@ -360,7 +360,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}

View File

@ -407,8 +407,7 @@ class TestTencentDataTrace:
mock_db.engine = "engine"
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.side_effect = [app, account]
session.query.return_value.filter_by.return_value.first.return_value = tenant_join
session.scalar.side_effect = [app, account, tenant_join]
with patch(
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"

View File

@ -76,10 +76,7 @@ def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
mock_account = MagicMock(spec=Account)
mock_account.id = "creator_id"
mock_db_session.scalar.side_effect = [mock_app, mock_account]
# session.query(TenantAccountJoin).filter_by(...).first() returns None
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
mock_db_session.scalar.side_effect = [mock_app, mock_account, None]
config = MagicMock(spec=BaseTracingConfig)
instance = ConcreteTraceInstance(config)
@ -97,11 +94,10 @@ def test_get_service_account_with_tenant_success(mock_db_session):
mock_account.id = "creator_id"
mock_account.set_tenant_id = MagicMock()
mock_db_session.scalar.side_effect = [mock_app, mock_account]
mock_tenant_join = MagicMock(spec=TenantAccountJoin)
mock_tenant_join.tenant_id = "tenant_id"
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
mock_db_session.scalar.side_effect = [mock_app, mock_account, mock_tenant_join]
config = MagicMock(spec=BaseTracingConfig)
instance = ConcreteTraceInstance(config)