mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: migrate session.query to select API in core misc modules (#34608)
This commit is contained in:
parent
2f9667de76
commit
b55bef4438
@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
:return:
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
agent_thought: MessageAgentThought | None = (
|
||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
||||
agent_thought: MessageAgentThought | None = session.scalar(
|
||||
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
|
||||
)
|
||||
|
||||
if agent_thought:
|
||||
|
||||
@ -345,8 +345,8 @@ class DatasourceManager:
|
||||
@classmethod
|
||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
||||
with session_factory.create_session() as session:
|
||||
upload_file = (
|
||||
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
|
||||
upload_file = session.scalar(
|
||||
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
|
||||
)
|
||||
if not upload_file:
|
||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||
|
||||
@ -467,7 +467,7 @@ class LLMGenerator:
|
||||
):
|
||||
session = db.session()
|
||||
|
||||
app: App | None = session.query(App).where(App.id == flow_id).first()
|
||||
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
|
||||
if not app:
|
||||
raise ValueError("App not found.")
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
|
||||
@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
current_tenant = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||
.limit(1)
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
|
||||
@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
|
||||
if not service_account:
|
||||
raise ValueError(f"Creator account not found for app {app_id}")
|
||||
|
||||
current_tenant = (
|
||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
||||
current_tenant = session.scalar(
|
||||
select(TenantAccountJoin)
|
||||
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||
.limit(1)
|
||||
)
|
||||
if not current_tenant:
|
||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||
|
||||
@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
def scalar(self, *args, **kwargs):
|
||||
return agent_thought
|
||||
|
||||
monkeypatch.setattr(
|
||||
@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def query(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
def scalar(self, *args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session)
|
||||
|
||||
@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
||||
source_url="http://x",
|
||||
)
|
||||
|
||||
class _Q:
|
||||
def __init__(self, row):
|
||||
self._row = row
|
||||
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._row
|
||||
|
||||
class _S:
|
||||
def __init__(self, row):
|
||||
self._row = row
|
||||
@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def query(self, *_):
|
||||
return _Q(self._row)
|
||||
def scalar(self, *_args, **_kwargs):
|
||||
return self._row
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
|
||||
|
||||
@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_raises_when_missing(mocker):
|
||||
class _Q:
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
class _S:
|
||||
def __enter__(self):
|
||||
return self
|
||||
@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker):
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def query(self, *_):
|
||||
return _Q()
|
||||
def scalar(self, *_args, **_kwargs):
|
||||
return None
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())
|
||||
|
||||
|
||||
@ -346,13 +346,13 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_workflow_app_not_found(self):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = None
|
||||
mock_session.return_value.scalar.return_value = None
|
||||
with pytest.raises(ValueError, match="App not found."):
|
||||
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
|
||||
|
||||
def test_instruction_modify_workflow_no_workflow(self):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
mock_session.return_value.scalar.return_value = MagicMock()
|
||||
workflow_service = MagicMock()
|
||||
workflow_service.get_draft_workflow.return_value = None
|
||||
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
|
||||
@ -360,7 +360,7 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session") as mock_session:
|
||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
||||
mock_session.return_value.scalar.return_value = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||
|
||||
|
||||
@ -407,8 +407,7 @@ class TestTencentDataTrace:
|
||||
mock_db.engine = "engine"
|
||||
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
|
||||
session = mock_session_ctx.return_value.__enter__.return_value
|
||||
session.scalar.side_effect = [app, account]
|
||||
session.query.return_value.filter_by.return_value.first.return_value = tenant_join
|
||||
session.scalar.side_effect = [app, account, tenant_join]
|
||||
|
||||
with patch(
|
||||
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
|
||||
|
||||
@ -76,10 +76,7 @@ def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
|
||||
mock_account = MagicMock(spec=Account)
|
||||
mock_account.id = "creator_id"
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
||||
|
||||
# session.query(TenantAccountJoin).filter_by(...).first() returns None
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account, None]
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
@ -97,11 +94,10 @@ def test_get_service_account_with_tenant_success(mock_db_session):
|
||||
mock_account.id = "creator_id"
|
||||
mock_account.set_tenant_id = MagicMock()
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
||||
|
||||
mock_tenant_join = MagicMock(spec=TenantAccountJoin)
|
||||
mock_tenant_join.tenant_id = "tenant_id"
|
||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
|
||||
|
||||
mock_db_session.scalar.side_effect = [mock_app, mock_account, mock_tenant_join]
|
||||
|
||||
config = MagicMock(spec=BaseTracingConfig)
|
||||
instance = ConcreteTraceInstance(config)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user