mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: use sessionmaker in small services 2 (#34696)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
909c062ee1
commit
a65e1f71b4
@ -9,7 +9,7 @@ from typing import Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
|
||||
class InvitationData(TypedDict):
|
||||
@ -1516,7 +1516,7 @@ class RegisterService:
|
||||
|
||||
check_workspace_member_invite_permission(tenant.id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||
|
||||
if not account:
|
||||
|
||||
@ -11,7 +11,7 @@ from typing import Any, Union
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from enums.quota_type import QuotaType
|
||||
from extensions.ext_database import db
|
||||
@ -237,7 +237,7 @@ class AsyncWorkflowService:
|
||||
Returns:
|
||||
Trigger log as dictionary or None if not found
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
|
||||
|
||||
@ -263,7 +263,7 @@ class AsyncWorkflowService:
|
||||
Returns:
|
||||
List of trigger logs as dictionaries
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
logs = trigger_log_repo.get_recent_logs(
|
||||
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
|
||||
@ -286,7 +286,7 @@ class AsyncWorkflowService:
|
||||
Returns:
|
||||
List of failed trigger logs as dictionaries
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
logs = trigger_log_repo.get_failed_for_retry(
|
||||
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
|
||||
|
||||
@ -346,7 +346,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||
current_time = started_at
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
total_tenant_count = session.query(Tenant.id).count()
|
||||
|
||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||
@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs:
|
||||
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
||||
interval = datetime.timedelta(days=1)
|
||||
# Process tenants in this batch
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
# Calculate tenant count in next batch with current interval
|
||||
# Try different intervals until we find one with a reasonable tenant count
|
||||
test_intervals = [
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.errors.error import QuotaExceededError
|
||||
@ -71,7 +71,7 @@ class CreditPoolService:
|
||||
actual_credits = min(credits_required, pool.remaining_credits)
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
stmt = (
|
||||
update(TenantCreditPool)
|
||||
.where(
|
||||
@ -81,7 +81,6 @@ class CreditPoolService:
|
||||
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
except Exception:
|
||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||
raise QuotaExceededError("Failed to deduct credits")
|
||||
|
||||
@ -15,7 +15,7 @@ from graphon.model_runtime.entities.model_entities import ModelFeature, ModelTyp
|
||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||
from redis.exceptions import LockNotOwnedError
|
||||
from sqlalchemy import delete, exists, func, select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@ -551,7 +551,7 @@ class DatasetService:
|
||||
external_knowledge_id: External knowledge identifier
|
||||
external_knowledge_api_id: External knowledge API identifier
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
external_knowledge_binding = (
|
||||
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
|
||||
)
|
||||
@ -559,14 +559,14 @@ class DatasetService:
|
||||
if not external_knowledge_binding:
|
||||
raise ValueError("External knowledge binding not found.")
|
||||
|
||||
# Update binding if values have changed
|
||||
if (
|
||||
external_knowledge_binding.external_knowledge_id != external_knowledge_id
|
||||
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
|
||||
):
|
||||
external_knowledge_binding.external_knowledge_id = external_knowledge_id
|
||||
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
|
||||
db.session.add(external_knowledge_binding)
|
||||
# Update binding if values have changed
|
||||
if (
|
||||
external_knowledge_binding.external_knowledge_id != external_knowledge_id
|
||||
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
|
||||
):
|
||||
external_knowledge_binding.external_knowledge_id = external_knowledge_id
|
||||
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
|
||||
session.add(external_knowledge_binding)
|
||||
|
||||
@staticmethod
|
||||
def _update_internal_dataset(dataset, data, user):
|
||||
|
||||
@ -2,7 +2,7 @@ import enum
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from extensions.ext_database import db
|
||||
@ -29,7 +29,7 @@ class OAuthServerService:
|
||||
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
|
||||
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||
return session.execute(query).scalar_one_or_none()
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -1182,7 +1182,7 @@ class RagPipelineService:
|
||||
workflow = db.session.get(Workflow, pipeline.workflow_id)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
dataset = pipeline.retrieve_dataset(session=session)
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
@ -1209,7 +1209,7 @@ class RagPipelineService:
|
||||
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
|
||||
if args.get("icon_info") is None:
|
||||
|
||||
@ -834,7 +834,7 @@ class WorkflowService:
|
||||
if workflow_node_execution is None:
|
||||
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
|
||||
|
||||
with Session(db.engine) as session:
|
||||
with sessionmaker(db.engine).begin() as session:
|
||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||
|
||||
with Session(bind=db.engine) as session, session.begin():
|
||||
|
||||
@ -1427,16 +1427,18 @@ class TestRegisterService:
|
||||
mock_tenant.name = "Test Workspace"
|
||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||
|
||||
# Mock database queries - need to mock the Session query
|
||||
# Mock database queries - need to mock the sessionmaker query
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
# Mock RegisterService.register
|
||||
@ -1485,12 +1487,14 @@ class TestRegisterService:
|
||||
mixed_email = "Invitee@Example.com"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = None
|
||||
|
||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||
@ -1541,16 +1545,18 @@ class TestRegisterService:
|
||||
account_id="existing-user-456", email="existing@example.com", status="pending"
|
||||
)
|
||||
|
||||
# Mock database queries - need to mock the Session query
|
||||
# Mock database queries - need to mock the sessionmaker query
|
||||
mock_session = MagicMock()
|
||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.account_service.Session") as mock_session_class,
|
||||
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||
):
|
||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
||||
mock_session_class.return_value.__exit__.return_value = None
|
||||
mock_lookup.return_value = mock_existing_account
|
||||
|
||||
# Mock scalar for TenantAccountJoin lookup - no existing member
|
||||
|
||||
@ -357,11 +357,12 @@ class TestAsyncWorkflowService:
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value = mock_session_context
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
||||
patch.object(
|
||||
async_workflow_service_module, "Session", return_value=mock_session_context
|
||||
) as mock_session_class,
|
||||
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
@ -373,7 +374,7 @@ class TestAsyncWorkflowService:
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_session_class.assert_called_once_with(fake_engine)
|
||||
mock_sessionmaker.assert_called_once_with(fake_engine)
|
||||
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
|
||||
|
||||
def test_should_return_recent_logs_as_dict_list(self):
|
||||
@ -391,9 +392,12 @@ class TestAsyncWorkflowService:
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value = mock_session_context
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
||||
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
@ -432,9 +436,12 @@ class TestAsyncWorkflowService:
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value = mock_session_context
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
||||
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
|
||||
@ -209,8 +209,22 @@ def _session_wrapper_for_no_autoflush(session: Mock) -> Mock:
|
||||
return wrapper
|
||||
|
||||
|
||||
def _sessionmaker_wrapper_for_begin(session: Mock) -> Mock:
|
||||
"""
|
||||
ClearFreePlanTenantExpiredLogs.process uses: with sessionmaker(db.engine).begin() as session:
|
||||
so sessionmaker(db.engine) must return an object with a begin() method that returns a context manager.
|
||||
"""
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
begin_cm.__exit__.return_value = None
|
||||
|
||||
sessionmaker_result = MagicMock()
|
||||
sessionmaker_result.begin.return_value = begin_cm
|
||||
return sessionmaker_result
|
||||
|
||||
|
||||
def _session_wrapper_for_direct(session: Mock) -> Mock:
|
||||
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:"""
|
||||
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session: (for old code paths)"""
|
||||
wrapper = MagicMock()
|
||||
wrapper.__enter__.return_value = session
|
||||
wrapper.__exit__.return_value = None
|
||||
@ -348,7 +362,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
|
||||
count_query.count.return_value = 2
|
||||
count_session.query.return_value = count_query
|
||||
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
|
||||
|
||||
# Avoid LocalProxy usage
|
||||
flask_app = service_module.Flask("test-app")
|
||||
@ -438,8 +452,8 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
|
||||
|
||||
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
|
||||
|
||||
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
|
||||
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
|
||||
|
||||
process_tenant_mock = MagicMock()
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||
@ -457,7 +471,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
|
||||
count_query = MagicMock()
|
||||
count_query.count.return_value = 100
|
||||
count_session.query.return_value = count_query
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
|
||||
|
||||
flask_app = service_module.Flask("test-app")
|
||||
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
||||
@ -523,8 +537,8 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
|
||||
|
||||
batch_session.query.side_effect = [*count_queries, q_rs]
|
||||
|
||||
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
|
||||
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
|
||||
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
|
||||
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
|
||||
|
||||
process_tenant_mock = MagicMock()
|
||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||
|
||||
@ -578,26 +578,33 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
|
||||
session = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
session.add = MagicMock()
|
||||
session_context = _make_session_context(session)
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value = session_context
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db") as mock_db,
|
||||
patch("services.dataset_service.Session", return_value=session_context),
|
||||
patch("services.dataset_service.sessionmaker", mock_sessionmaker),
|
||||
):
|
||||
DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
|
||||
|
||||
assert binding.external_knowledge_id == "new-knowledge"
|
||||
assert binding.external_knowledge_api_id == "new-api"
|
||||
mock_db.session.add.assert_called_once_with(binding)
|
||||
session.add.assert_called_once_with(binding)
|
||||
|
||||
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
|
||||
session = MagicMock()
|
||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
session_context = _make_session_context(session)
|
||||
|
||||
mock_sessionmaker = MagicMock()
|
||||
mock_sessionmaker.return_value.begin.return_value = session_context
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.db"),
|
||||
patch("services.dataset_service.Session", return_value=session_context),
|
||||
patch("services.dataset_service.sessionmaker", mock_sessionmaker),
|
||||
):
|
||||
with pytest.raises(ValueError, match="External knowledge binding not found"):
|
||||
DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user