refactor: migrate session.query to select API in core misc modules (#34608)

This commit is contained in:
Renzo 2026-04-06 23:08:34 -05:00 committed by GitHub
parent 2f9667de76
commit b55bef4438
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 26 additions and 56 deletions

View File

@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return: :return:
""" """
with Session(db.engine, expire_on_commit=False) as session: with Session(db.engine, expire_on_commit=False) as session:
agent_thought: MessageAgentThought | None = ( agent_thought: MessageAgentThought | None = session.scalar(
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first() select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
) )
if agent_thought: if agent_thought:

View File

@ -345,8 +345,8 @@ class DatasourceManager:
@classmethod @classmethod
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File: def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
with session_factory.create_session() as session: with session_factory.create_session() as session:
upload_file = ( upload_file = session.scalar(
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first() select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
) )
if not upload_file: if not upload_file:
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}") raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")

View File

@ -467,7 +467,7 @@ class LLMGenerator:
): ):
session = db.session() session = db.session()
app: App | None = session.query(App).where(App.id == flow_id).first() app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
if not app: if not app:
raise ValueError("App not found.") raise ValueError("App not found.")
workflow = workflow_service.get_draft_workflow(app_model=app) workflow = workflow_service.get_draft_workflow(app_model=app)

View File

@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
if not service_account: if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = ( current_tenant = session.scalar(
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
) )
if not current_tenant: if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}") raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
if not service_account: if not service_account:
raise ValueError(f"Creator account not found for app {app_id}") raise ValueError(f"Creator account not found for app {app_id}")
current_tenant = ( current_tenant = session.scalar(
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() select(TenantAccountJoin)
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
.limit(1)
) )
if not current_tenant: if not current_tenant:
raise ValueError(f"Current tenant not found for account {service_account.id}") raise ValueError(f"Current tenant not found for account {service_account.id}")

View File

@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
def query(self, *args, **kwargs): def scalar(self, *args, **kwargs):
return self
def where(self, *args, **kwargs):
return self
def first(self):
return agent_thought return agent_thought
monkeypatch.setattr( monkeypatch.setattr(
@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
def query(self, *args, **kwargs): def scalar(self, *args, **kwargs):
return self
def where(self, *args, **kwargs):
return self
def first(self):
return None return None
monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session) monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session)

View File

@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
source_url="http://x", source_url="http://x",
) )
class _Q:
def __init__(self, row):
self._row = row
def where(self, *_args, **_kwargs):
return self
def first(self):
return self._row
class _S: class _S:
def __init__(self, row): def __init__(self, row):
self._row = row self._row = row
@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker):
def __exit__(self, *exc): def __exit__(self, *exc):
return False return False
def query(self, *_): def scalar(self, *_args, **_kwargs):
return _Q(self._row) return self._row
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row)) mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
def test_get_upload_file_by_id_raises_when_missing(mocker): def test_get_upload_file_by_id_raises_when_missing(mocker):
class _Q:
def where(self, *_args, **_kwargs):
return self
def first(self):
return None
class _S: class _S:
def __enter__(self): def __enter__(self):
return self return self
@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker):
def __exit__(self, *exc): def __exit__(self, *exc):
return False return False
def query(self, *_): def scalar(self, *_args, **_kwargs):
return _Q() return None
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())

View File

@ -346,13 +346,13 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_app_not_found(self): def test_instruction_modify_workflow_app_not_found(self):
with patch("extensions.ext_database.db.session") as mock_session: with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = None mock_session.return_value.scalar.return_value = None
with pytest.raises(ValueError, match="App not found."): with pytest.raises(ValueError, match="App not found."):
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock()) LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
def test_instruction_modify_workflow_no_workflow(self): def test_instruction_modify_workflow_no_workflow(self):
with patch("extensions.ext_database.db.session") as mock_session: with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() mock_session.return_value.scalar.return_value = MagicMock()
workflow_service = MagicMock() workflow_service = MagicMock()
workflow_service.get_draft_workflow.return_value = None workflow_service.get_draft_workflow.return_value = None
with pytest.raises(ValueError, match="Workflow not found for the given app model."): with pytest.raises(ValueError, match="Workflow not found for the given app model."):
@ -360,7 +360,7 @@ class TestLLMGenerator:
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity): def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session") as mock_session: with patch("extensions.ext_database.db.session") as mock_session:
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() mock_session.return_value.scalar.return_value = MagicMock()
workflow = MagicMock() workflow = MagicMock()
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}

View File

@ -407,8 +407,7 @@ class TestTencentDataTrace:
mock_db.engine = "engine" mock_db.engine = "engine"
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
session = mock_session_ctx.return_value.__enter__.return_value session = mock_session_ctx.return_value.__enter__.return_value
session.scalar.side_effect = [app, account] session.scalar.side_effect = [app, account, tenant_join]
session.query.return_value.filter_by.return_value.first.return_value = tenant_join
with patch( with patch(
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"

View File

@ -76,10 +76,7 @@ def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
mock_account = MagicMock(spec=Account) mock_account = MagicMock(spec=Account)
mock_account.id = "creator_id" mock_account.id = "creator_id"
mock_db_session.scalar.side_effect = [mock_app, mock_account] mock_db_session.scalar.side_effect = [mock_app, mock_account, None]
# session.query(TenantAccountJoin).filter_by(...).first() returns None
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
config = MagicMock(spec=BaseTracingConfig) config = MagicMock(spec=BaseTracingConfig)
instance = ConcreteTraceInstance(config) instance = ConcreteTraceInstance(config)
@ -97,11 +94,10 @@ def test_get_service_account_with_tenant_success(mock_db_session):
mock_account.id = "creator_id" mock_account.id = "creator_id"
mock_account.set_tenant_id = MagicMock() mock_account.set_tenant_id = MagicMock()
mock_db_session.scalar.side_effect = [mock_app, mock_account]
mock_tenant_join = MagicMock(spec=TenantAccountJoin) mock_tenant_join = MagicMock(spec=TenantAccountJoin)
mock_tenant_join.tenant_id = "tenant_id" mock_tenant_join.tenant_id = "tenant_id"
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
mock_db_session.scalar.side_effect = [mock_app, mock_account, mock_tenant_join]
config = MagicMock(spec=BaseTracingConfig) config = MagicMock(spec=BaseTracingConfig)
instance = ConcreteTraceInstance(config) instance = ConcreteTraceInstance(config)