mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 19:27:23 +08:00
refactor: use sessionmaker in small services 2 (#34696)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
909c062ee1
commit
a65e1f71b4
@ -9,7 +9,7 @@ from typing import Any, TypedDict, cast
|
|||||||
|
|
||||||
from pydantic import BaseModel, TypeAdapter
|
from pydantic import BaseModel, TypeAdapter
|
||||||
from sqlalchemy import delete, func, select, update
|
from sqlalchemy import delete, func, select, update
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
|
||||||
class InvitationData(TypedDict):
|
class InvitationData(TypedDict):
|
||||||
@ -1516,7 +1516,7 @@ class RegisterService:
|
|||||||
|
|
||||||
check_workspace_member_invite_permission(tenant.id)
|
check_workspace_member_invite_permission(tenant.id)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
|
||||||
|
|
||||||
if not account:
|
if not account:
|
||||||
|
|||||||
@ -11,7 +11,7 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from celery.result import AsyncResult
|
from celery.result import AsyncResult
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from enums.quota_type import QuotaType
|
from enums.quota_type import QuotaType
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -237,7 +237,7 @@ class AsyncWorkflowService:
|
|||||||
Returns:
|
Returns:
|
||||||
Trigger log as dictionary or None if not found
|
Trigger log as dictionary or None if not found
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
|
trigger_log = trigger_log_repo.get_by_id(workflow_trigger_log_id, tenant_id)
|
||||||
|
|
||||||
@ -263,7 +263,7 @@ class AsyncWorkflowService:
|
|||||||
Returns:
|
Returns:
|
||||||
List of trigger logs as dictionaries
|
List of trigger logs as dictionaries
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
logs = trigger_log_repo.get_recent_logs(
|
logs = trigger_log_repo.get_recent_logs(
|
||||||
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
|
tenant_id=tenant_id, app_id=app_id, hours=hours, limit=limit, offset=offset
|
||||||
@ -286,7 +286,7 @@ class AsyncWorkflowService:
|
|||||||
Returns:
|
Returns:
|
||||||
List of failed trigger logs as dictionaries
|
List of failed trigger logs as dictionaries
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||||
logs = trigger_log_repo.get_failed_for_retry(
|
logs = trigger_log_repo.get_failed_for_retry(
|
||||||
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
|
tenant_id=tenant_id, max_retry_count=max_retry_count, limit=limit
|
||||||
|
|||||||
@ -346,7 +346,7 @@ class ClearFreePlanTenantExpiredLogs:
|
|||||||
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
started_at = datetime.datetime(2023, 4, 3, 8, 59, 24)
|
||||||
current_time = started_at
|
current_time = started_at
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
total_tenant_count = session.query(Tenant.id).count()
|
total_tenant_count = session.query(Tenant.id).count()
|
||||||
|
|
||||||
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
click.echo(click.style(f"Total tenant count: {total_tenant_count}", fg="white"))
|
||||||
@ -398,7 +398,7 @@ class ClearFreePlanTenantExpiredLogs:
|
|||||||
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
# Initial interval of 1 day, will be dynamically adjusted based on tenant count
|
||||||
interval = datetime.timedelta(days=1)
|
interval = datetime.timedelta(days=1)
|
||||||
# Process tenants in this batch
|
# Process tenants in this batch
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
# Calculate tenant count in next batch with current interval
|
# Calculate tenant count in next batch with current interval
|
||||||
# Try different intervals until we find one with a reasonable tenant count
|
# Try different intervals until we find one with a reasonable tenant count
|
||||||
test_intervals = [
|
test_intervals = [
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy import select, update
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.errors.error import QuotaExceededError
|
from core.errors.error import QuotaExceededError
|
||||||
@ -71,7 +71,7 @@ class CreditPoolService:
|
|||||||
actual_credits = min(credits_required, pool.remaining_credits)
|
actual_credits = min(credits_required, pool.remaining_credits)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
stmt = (
|
stmt = (
|
||||||
update(TenantCreditPool)
|
update(TenantCreditPool)
|
||||||
.where(
|
.where(
|
||||||
@ -81,7 +81,6 @@ class CreditPoolService:
|
|||||||
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
|
.values(quota_used=TenantCreditPool.quota_used + actual_credits)
|
||||||
)
|
)
|
||||||
session.execute(stmt)
|
session.execute(stmt)
|
||||||
session.commit()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
logger.exception("Failed to deduct credits for tenant %s", tenant_id)
|
||||||
raise QuotaExceededError("Failed to deduct credits")
|
raise QuotaExceededError("Failed to deduct credits")
|
||||||
|
|||||||
@ -15,7 +15,7 @@ from graphon.model_runtime.entities.model_entities import ModelFeature, ModelTyp
|
|||||||
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
|
||||||
from redis.exceptions import LockNotOwnedError
|
from redis.exceptions import LockNotOwnedError
|
||||||
from sqlalchemy import delete, exists, func, select, update
|
from sqlalchemy import delete, exists, func, select, update
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
from werkzeug.exceptions import Forbidden, NotFound
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -551,7 +551,7 @@ class DatasetService:
|
|||||||
external_knowledge_id: External knowledge identifier
|
external_knowledge_id: External knowledge identifier
|
||||||
external_knowledge_api_id: External knowledge API identifier
|
external_knowledge_api_id: External knowledge API identifier
|
||||||
"""
|
"""
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
external_knowledge_binding = (
|
external_knowledge_binding = (
|
||||||
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
|
session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
|
||||||
)
|
)
|
||||||
@ -559,14 +559,14 @@ class DatasetService:
|
|||||||
if not external_knowledge_binding:
|
if not external_knowledge_binding:
|
||||||
raise ValueError("External knowledge binding not found.")
|
raise ValueError("External knowledge binding not found.")
|
||||||
|
|
||||||
# Update binding if values have changed
|
# Update binding if values have changed
|
||||||
if (
|
if (
|
||||||
external_knowledge_binding.external_knowledge_id != external_knowledge_id
|
external_knowledge_binding.external_knowledge_id != external_knowledge_id
|
||||||
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
|
or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
|
||||||
):
|
):
|
||||||
external_knowledge_binding.external_knowledge_id = external_knowledge_id
|
external_knowledge_binding.external_knowledge_id = external_knowledge_id
|
||||||
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
|
external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
|
||||||
db.session.add(external_knowledge_binding)
|
session.add(external_knowledge_binding)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _update_internal_dataset(dataset, data, user):
|
def _update_internal_dataset(dataset, data, user):
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import enum
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
from werkzeug.exceptions import BadRequest
|
from werkzeug.exceptions import BadRequest
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -29,7 +29,7 @@ class OAuthServerService:
|
|||||||
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
|
def get_oauth_provider_app(client_id: str) -> OAuthProviderApp | None:
|
||||||
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
|
query = select(OAuthProviderApp).where(OAuthProviderApp.client_id == client_id)
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
|
||||||
return session.execute(query).scalar_one_or_none()
|
return session.execute(query).scalar_one_or_none()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -1182,7 +1182,7 @@ class RagPipelineService:
|
|||||||
workflow = db.session.get(Workflow, pipeline.workflow_id)
|
workflow = db.session.get(Workflow, pipeline.workflow_id)
|
||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("Workflow not found")
|
raise ValueError("Workflow not found")
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
dataset = pipeline.retrieve_dataset(session=session)
|
dataset = pipeline.retrieve_dataset(session=session)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset not found")
|
raise ValueError("Dataset not found")
|
||||||
@ -1209,7 +1209,7 @@ class RagPipelineService:
|
|||||||
|
|
||||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
rag_pipeline_dsl_service = RagPipelineDslService(session)
|
||||||
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
|
dsl = rag_pipeline_dsl_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=True)
|
||||||
if args.get("icon_info") is None:
|
if args.get("icon_info") is None:
|
||||||
|
|||||||
@ -834,7 +834,7 @@ class WorkflowService:
|
|||||||
if workflow_node_execution is None:
|
if workflow_node_execution is None:
|
||||||
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
|
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
|
||||||
|
|
||||||
with Session(db.engine) as session:
|
with sessionmaker(db.engine).begin() as session:
|
||||||
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
outputs = workflow_node_execution.load_full_outputs(session, storage)
|
||||||
|
|
||||||
with Session(bind=db.engine) as session, session.begin():
|
with Session(bind=db.engine) as session, session.begin():
|
||||||
|
|||||||
@ -1427,16 +1427,18 @@ class TestRegisterService:
|
|||||||
mock_tenant.name = "Test Workspace"
|
mock_tenant.name = "Test Workspace"
|
||||||
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter")
|
||||||
|
|
||||||
# Mock database queries - need to mock the Session query
|
# Mock database queries - need to mock the sessionmaker query
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.account_service.Session") as mock_session_class,
|
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||||
):
|
):
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
|
||||||
mock_session_class.return_value.__exit__.return_value = None
|
|
||||||
mock_lookup.return_value = None
|
mock_lookup.return_value = None
|
||||||
|
|
||||||
# Mock RegisterService.register
|
# Mock RegisterService.register
|
||||||
@ -1485,12 +1487,14 @@ class TestRegisterService:
|
|||||||
mixed_email = "Invitee@Example.com"
|
mixed_email = "Invitee@Example.com"
|
||||||
|
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.account_service.Session") as mock_session_class,
|
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||||
):
|
):
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
|
||||||
mock_session_class.return_value.__exit__.return_value = None
|
|
||||||
mock_lookup.return_value = None
|
mock_lookup.return_value = None
|
||||||
|
|
||||||
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
mock_new_account = TestAccountAssociatedDataFactory.create_account_mock(
|
||||||
@ -1541,16 +1545,18 @@ class TestRegisterService:
|
|||||||
account_id="existing-user-456", email="existing@example.com", status="pending"
|
account_id="existing-user-456", email="existing@example.com", status="pending"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock database queries - need to mock the Session query
|
# Mock database queries - need to mock the sessionmaker query
|
||||||
mock_session = MagicMock()
|
mock_session = MagicMock()
|
||||||
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value.__enter__.return_value = mock_session
|
||||||
|
mock_sessionmaker.return_value.begin.return_value.__exit__.return_value = None
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.account_service.Session") as mock_session_class,
|
patch("services.account_service.sessionmaker", mock_sessionmaker),
|
||||||
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
patch("services.account_service.AccountService.get_account_by_email_with_case_fallback") as mock_lookup,
|
||||||
):
|
):
|
||||||
mock_session_class.return_value.__enter__.return_value = mock_session
|
|
||||||
mock_session_class.return_value.__exit__.return_value = None
|
|
||||||
mock_lookup.return_value = mock_existing_account
|
mock_lookup.return_value = mock_existing_account
|
||||||
|
|
||||||
# Mock scalar for TenantAccountJoin lookup - no existing member
|
# Mock scalar for TenantAccountJoin lookup - no existing member
|
||||||
|
|||||||
@ -357,11 +357,12 @@ class TestAsyncWorkflowService:
|
|||||||
mock_session_context.__enter__.return_value = mock_session
|
mock_session_context.__enter__.return_value = mock_session
|
||||||
mock_session_context.__exit__.return_value = None
|
mock_session_context.__exit__.return_value = None
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value = mock_session_context
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
||||||
patch.object(
|
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
|
||||||
async_workflow_service_module, "Session", return_value=mock_session_context
|
|
||||||
) as mock_session_class,
|
|
||||||
patch.object(
|
patch.object(
|
||||||
async_workflow_service_module,
|
async_workflow_service_module,
|
||||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||||
@ -373,7 +374,7 @@ class TestAsyncWorkflowService:
|
|||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert result == expected
|
assert result == expected
|
||||||
mock_session_class.assert_called_once_with(fake_engine)
|
mock_sessionmaker.assert_called_once_with(fake_engine)
|
||||||
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
|
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
|
||||||
|
|
||||||
def test_should_return_recent_logs_as_dict_list(self):
|
def test_should_return_recent_logs_as_dict_list(self):
|
||||||
@ -391,9 +392,12 @@ class TestAsyncWorkflowService:
|
|||||||
mock_session_context.__enter__.return_value = mock_session
|
mock_session_context.__enter__.return_value = mock_session
|
||||||
mock_session_context.__exit__.return_value = None
|
mock_session_context.__exit__.return_value = None
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value = mock_session_context
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
|
||||||
patch.object(
|
patch.object(
|
||||||
async_workflow_service_module,
|
async_workflow_service_module,
|
||||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||||
@ -432,9 +436,12 @@ class TestAsyncWorkflowService:
|
|||||||
mock_session_context.__enter__.return_value = mock_session
|
mock_session_context.__enter__.return_value = mock_session
|
||||||
mock_session_context.__exit__.return_value = None
|
mock_session_context.__exit__.return_value = None
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value = mock_session_context
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
patch.object(async_workflow_service_module, "sessionmaker", mock_sessionmaker),
|
||||||
patch.object(
|
patch.object(
|
||||||
async_workflow_service_module,
|
async_workflow_service_module,
|
||||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||||
|
|||||||
@ -209,8 +209,22 @@ def _session_wrapper_for_no_autoflush(session: Mock) -> Mock:
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _sessionmaker_wrapper_for_begin(session: Mock) -> Mock:
|
||||||
|
"""
|
||||||
|
ClearFreePlanTenantExpiredLogs.process uses: with sessionmaker(db.engine).begin() as session:
|
||||||
|
so sessionmaker(db.engine) must return an object with a begin() method that returns a context manager.
|
||||||
|
"""
|
||||||
|
begin_cm = MagicMock()
|
||||||
|
begin_cm.__enter__.return_value = session
|
||||||
|
begin_cm.__exit__.return_value = None
|
||||||
|
|
||||||
|
sessionmaker_result = MagicMock()
|
||||||
|
sessionmaker_result.begin.return_value = begin_cm
|
||||||
|
return sessionmaker_result
|
||||||
|
|
||||||
|
|
||||||
def _session_wrapper_for_direct(session: Mock) -> Mock:
|
def _session_wrapper_for_direct(session: Mock) -> Mock:
|
||||||
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:"""
|
"""ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session: (for old code paths)"""
|
||||||
wrapper = MagicMock()
|
wrapper = MagicMock()
|
||||||
wrapper.__enter__.return_value = session
|
wrapper.__enter__.return_value = session
|
||||||
wrapper.__exit__.return_value = None
|
wrapper.__exit__.return_value = None
|
||||||
@ -348,7 +362,7 @@ def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: py
|
|||||||
count_query.count.return_value = 2
|
count_query.count.return_value = 2
|
||||||
count_session.query.return_value = count_query
|
count_session.query.return_value = count_query
|
||||||
|
|
||||||
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
|
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
|
||||||
|
|
||||||
# Avoid LocalProxy usage
|
# Avoid LocalProxy usage
|
||||||
flask_app = service_module.Flask("test-app")
|
flask_app = service_module.Flask("test-app")
|
||||||
@ -438,8 +452,8 @@ def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pyt
|
|||||||
|
|
||||||
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
|
batch_session.query.side_effect = [q1, q2, q3, q4, q_rs]
|
||||||
|
|
||||||
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
|
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
|
||||||
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
|
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
|
||||||
|
|
||||||
process_tenant_mock = MagicMock()
|
process_tenant_mock = MagicMock()
|
||||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||||
@ -457,7 +471,7 @@ def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.Mo
|
|||||||
count_query = MagicMock()
|
count_query = MagicMock()
|
||||||
count_query.count.return_value = 100
|
count_query.count.return_value = 100
|
||||||
count_session.query.return_value = count_query
|
count_session.query.return_value = count_query
|
||||||
monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session))
|
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: _sessionmaker_wrapper_for_begin(count_session))
|
||||||
|
|
||||||
flask_app = service_module.Flask("test-app")
|
flask_app = service_module.Flask("test-app")
|
||||||
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app))
|
||||||
@ -523,8 +537,8 @@ def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(mon
|
|||||||
|
|
||||||
batch_session.query.side_effect = [*count_queries, q_rs]
|
batch_session.query.side_effect = [*count_queries, q_rs]
|
||||||
|
|
||||||
sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)]
|
sessions = [_sessionmaker_wrapper_for_begin(total_session), _sessionmaker_wrapper_for_begin(batch_session)]
|
||||||
monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0))
|
monkeypatch.setattr(service_module, "sessionmaker", lambda _engine: sessions.pop(0))
|
||||||
|
|
||||||
process_tenant_mock = MagicMock()
|
process_tenant_mock = MagicMock()
|
||||||
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock)
|
||||||
|
|||||||
@ -578,26 +578,33 @@ class TestDatasetServiceCreationAndUpdate:
|
|||||||
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
|
binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = binding
|
session.query.return_value.filter_by.return_value.first.return_value = binding
|
||||||
|
session.add = MagicMock()
|
||||||
session_context = _make_session_context(session)
|
session_context = _make_session_context(session)
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value = session_context
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.dataset_service.db") as mock_db,
|
patch("services.dataset_service.db") as mock_db,
|
||||||
patch("services.dataset_service.Session", return_value=session_context),
|
patch("services.dataset_service.sessionmaker", mock_sessionmaker),
|
||||||
):
|
):
|
||||||
DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
|
DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
|
||||||
|
|
||||||
assert binding.external_knowledge_id == "new-knowledge"
|
assert binding.external_knowledge_id == "new-knowledge"
|
||||||
assert binding.external_knowledge_api_id == "new-api"
|
assert binding.external_knowledge_api_id == "new-api"
|
||||||
mock_db.session.add.assert_called_once_with(binding)
|
session.add.assert_called_once_with(binding)
|
||||||
|
|
||||||
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
|
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
|
||||||
session = MagicMock()
|
session = MagicMock()
|
||||||
session.query.return_value.filter_by.return_value.first.return_value = None
|
session.query.return_value.filter_by.return_value.first.return_value = None
|
||||||
session_context = _make_session_context(session)
|
session_context = _make_session_context(session)
|
||||||
|
|
||||||
|
mock_sessionmaker = MagicMock()
|
||||||
|
mock_sessionmaker.return_value.begin.return_value = session_context
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("services.dataset_service.db"),
|
patch("services.dataset_service.db"),
|
||||||
patch("services.dataset_service.Session", return_value=session_context),
|
patch("services.dataset_service.sessionmaker", mock_sessionmaker),
|
||||||
):
|
):
|
||||||
with pytest.raises(ValueError, match="External knowledge binding not found"):
|
with pytest.raises(ValueError, match="External knowledge binding not found"):
|
||||||
DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")
|
DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user