mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 01:26:57 +08:00
refactor: migrate session.query to select API in core misc modules (#34608)
This commit is contained in:
parent
2f9667de76
commit
b55bef4438
@ -509,8 +509,8 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with Session(db.engine, expire_on_commit=False) as session:
|
with Session(db.engine, expire_on_commit=False) as session:
|
||||||
agent_thought: MessageAgentThought | None = (
|
agent_thought: MessageAgentThought | None = session.scalar(
|
||||||
session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
|
select(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).limit(1)
|
||||||
)
|
)
|
||||||
|
|
||||||
if agent_thought:
|
if agent_thought:
|
||||||
|
|||||||
@ -345,8 +345,8 @@ class DatasourceManager:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
def get_upload_file_by_id(cls, file_id: str, tenant_id: str) -> File:
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
upload_file = (
|
upload_file = session.scalar(
|
||||||
session.query(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).first()
|
select(UploadFile).where(UploadFile.id == file_id, UploadFile.tenant_id == tenant_id).limit(1)
|
||||||
)
|
)
|
||||||
if not upload_file:
|
if not upload_file:
|
||||||
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
raise ValueError(f"UploadFile not found for file_id={file_id}, tenant_id={tenant_id}")
|
||||||
|
|||||||
@ -467,7 +467,7 @@ class LLMGenerator:
|
|||||||
):
|
):
|
||||||
session = db.session()
|
session = db.session()
|
||||||
|
|
||||||
app: App | None = session.query(App).where(App.id == flow_id).first()
|
app: App | None = session.scalar(select(App).where(App.id == flow_id).limit(1))
|
||||||
if not app:
|
if not app:
|
||||||
raise ValueError("App not found.")
|
raise ValueError("App not found.")
|
||||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||||
|
|||||||
@ -56,8 +56,10 @@ class BaseTraceInstance(ABC):
|
|||||||
if not service_account:
|
if not service_account:
|
||||||
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
|
||||||
|
|
||||||
current_tenant = (
|
current_tenant = session.scalar(
|
||||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not current_tenant:
|
if not current_tenant:
|
||||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||||
|
|||||||
@ -241,8 +241,10 @@ class TencentDataTrace(BaseTraceInstance):
|
|||||||
if not service_account:
|
if not service_account:
|
||||||
raise ValueError(f"Creator account not found for app {app_id}")
|
raise ValueError(f"Creator account not found for app {app_id}")
|
||||||
|
|
||||||
current_tenant = (
|
current_tenant = session.scalar(
|
||||||
session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
|
select(TenantAccountJoin)
|
||||||
|
.where(TenantAccountJoin.account_id == service_account.id, TenantAccountJoin.current.is_(True))
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
if not current_tenant:
|
if not current_tenant:
|
||||||
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
raise ValueError(f"Current tenant not found for account {service_account.id}")
|
||||||
|
|||||||
@ -505,13 +505,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
|
|||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def query(self, *args, **kwargs):
|
def scalar(self, *args, **kwargs):
|
||||||
return self
|
|
||||||
|
|
||||||
def where(self, *args, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return agent_thought
|
return agent_thought
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
@ -1182,13 +1176,7 @@ class TestEasyUiBasedGenerateTaskPipeline:
|
|||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def query(self, *args, **kwargs):
|
def scalar(self, *args, **kwargs):
|
||||||
return self
|
|
||||||
|
|
||||||
def where(self, *args, **kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session)
|
monkeypatch.setattr("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session", _Session)
|
||||||
|
|||||||
@ -632,16 +632,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
|||||||
source_url="http://x",
|
source_url="http://x",
|
||||||
)
|
)
|
||||||
|
|
||||||
class _Q:
|
|
||||||
def __init__(self, row):
|
|
||||||
self._row = row
|
|
||||||
|
|
||||||
def where(self, *_args, **_kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return self._row
|
|
||||||
|
|
||||||
class _S:
|
class _S:
|
||||||
def __init__(self, row):
|
def __init__(self, row):
|
||||||
self._row = row
|
self._row = row
|
||||||
@ -652,8 +642,8 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
|||||||
def __exit__(self, *exc):
|
def __exit__(self, *exc):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def query(self, *_):
|
def scalar(self, *_args, **_kwargs):
|
||||||
return _Q(self._row)
|
return self._row
|
||||||
|
|
||||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
|
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S(fake_row))
|
||||||
|
|
||||||
@ -665,13 +655,6 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
|||||||
|
|
||||||
|
|
||||||
def test_get_upload_file_by_id_raises_when_missing(mocker):
|
def test_get_upload_file_by_id_raises_when_missing(mocker):
|
||||||
class _Q:
|
|
||||||
def where(self, *_args, **_kwargs):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def first(self):
|
|
||||||
return None
|
|
||||||
|
|
||||||
class _S:
|
class _S:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@ -679,8 +662,8 @@ def test_get_upload_file_by_id_raises_when_missing(mocker):
|
|||||||
def __exit__(self, *exc):
|
def __exit__(self, *exc):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def query(self, *_):
|
def scalar(self, *_args, **_kwargs):
|
||||||
return _Q()
|
return None
|
||||||
|
|
||||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())
|
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())
|
||||||
|
|
||||||
|
|||||||
@ -346,13 +346,13 @@ class TestLLMGenerator:
|
|||||||
|
|
||||||
def test_instruction_modify_workflow_app_not_found(self):
|
def test_instruction_modify_workflow_app_not_found(self):
|
||||||
with patch("extensions.ext_database.db.session") as mock_session:
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = None
|
mock_session.return_value.scalar.return_value = None
|
||||||
with pytest.raises(ValueError, match="App not found."):
|
with pytest.raises(ValueError, match="App not found."):
|
||||||
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
|
LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock())
|
||||||
|
|
||||||
def test_instruction_modify_workflow_no_workflow(self):
|
def test_instruction_modify_workflow_no_workflow(self):
|
||||||
with patch("extensions.ext_database.db.session") as mock_session:
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
mock_session.return_value.scalar.return_value = MagicMock()
|
||||||
workflow_service = MagicMock()
|
workflow_service = MagicMock()
|
||||||
workflow_service.get_draft_workflow.return_value = None
|
workflow_service.get_draft_workflow.return_value = None
|
||||||
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
|
with pytest.raises(ValueError, match="Workflow not found for the given app model."):
|
||||||
@ -360,7 +360,7 @@ class TestLLMGenerator:
|
|||||||
|
|
||||||
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
|
def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity):
|
||||||
with patch("extensions.ext_database.db.session") as mock_session:
|
with patch("extensions.ext_database.db.session") as mock_session:
|
||||||
mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
|
mock_session.return_value.scalar.return_value = MagicMock()
|
||||||
workflow = MagicMock()
|
workflow = MagicMock()
|
||||||
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}}
|
||||||
|
|
||||||
|
|||||||
@ -407,8 +407,7 @@ class TestTencentDataTrace:
|
|||||||
mock_db.engine = "engine"
|
mock_db.engine = "engine"
|
||||||
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
|
with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx:
|
||||||
session = mock_session_ctx.return_value.__enter__.return_value
|
session = mock_session_ctx.return_value.__enter__.return_value
|
||||||
session.scalar.side_effect = [app, account]
|
session.scalar.side_effect = [app, account, tenant_join]
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = tenant_join
|
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
|
"core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository"
|
||||||
|
|||||||
@ -76,10 +76,7 @@ def test_get_service_account_with_tenant_tenant_not_found(mock_db_session):
|
|||||||
mock_account = MagicMock(spec=Account)
|
mock_account = MagicMock(spec=Account)
|
||||||
mock_account.id = "creator_id"
|
mock_account.id = "creator_id"
|
||||||
|
|
||||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
mock_db_session.scalar.side_effect = [mock_app, mock_account, None]
|
||||||
|
|
||||||
# session.query(TenantAccountJoin).filter_by(...).first() returns None
|
|
||||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = None
|
|
||||||
|
|
||||||
config = MagicMock(spec=BaseTracingConfig)
|
config = MagicMock(spec=BaseTracingConfig)
|
||||||
instance = ConcreteTraceInstance(config)
|
instance = ConcreteTraceInstance(config)
|
||||||
@ -97,11 +94,10 @@ def test_get_service_account_with_tenant_success(mock_db_session):
|
|||||||
mock_account.id = "creator_id"
|
mock_account.id = "creator_id"
|
||||||
mock_account.set_tenant_id = MagicMock()
|
mock_account.set_tenant_id = MagicMock()
|
||||||
|
|
||||||
mock_db_session.scalar.side_effect = [mock_app, mock_account]
|
|
||||||
|
|
||||||
mock_tenant_join = MagicMock(spec=TenantAccountJoin)
|
mock_tenant_join = MagicMock(spec=TenantAccountJoin)
|
||||||
mock_tenant_join.tenant_id = "tenant_id"
|
mock_tenant_join.tenant_id = "tenant_id"
|
||||||
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join
|
|
||||||
|
mock_db_session.scalar.side_effect = [mock_app, mock_account, mock_tenant_join]
|
||||||
|
|
||||||
config = MagicMock(spec=BaseTracingConfig)
|
config = MagicMock(spec=BaseTracingConfig)
|
||||||
instance = ConcreteTraceInstance(config)
|
instance = ConcreteTraceInstance(config)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user