mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 18:06:36 +08:00
refactor: migrate session.query to select API in small task files batch (#34684)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
5aa2524d33
commit
cb55176612
@ -8,7 +8,7 @@ import click
|
||||
import pandas as pd
|
||||
from celery import shared_task
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
from core.model_manager import ModelManager
|
||||
@ -140,10 +140,8 @@ def batch_create_segment_to_index_task(
|
||||
content = segment["content"]
|
||||
doc_id = str(uuid.uuid4())
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
max_position = (
|
||||
session.query(func.max(DocumentSegment.position))
|
||||
.where(DocumentSegment.document_id == document_config["id"])
|
||||
.scalar()
|
||||
max_position = session.scalar(
|
||||
select(func.max(DocumentSegment.position)).where(DocumentSegment.document_id == document_config["id"])
|
||||
)
|
||||
segment_document = DocumentSegment(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import logging
|
||||
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from core.db.session_factory import session_factory
|
||||
@ -14,7 +15,7 @@ logger = logging.getLogger(__name__)
|
||||
@shared_task(queue="dataset")
|
||||
def delete_account_task(account_id):
|
||||
with session_factory.create_session() as session:
|
||||
account = session.query(Account).where(Account.id == account_id).first()
|
||||
account = session.scalar(select(Account).where(Account.id == account_id).limit(1))
|
||||
try:
|
||||
if dify_config.BILLING_ENABLED:
|
||||
BillingService.delete_account(account_id)
|
||||
|
||||
@ -82,7 +82,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
try:
|
||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
|
||||
if dataset is None:
|
||||
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
return
|
||||
|
||||
@ -26,9 +26,6 @@ def mock_db_session():
|
||||
cm.__exit__.return_value = None
|
||||
mock_sf.create_session.return_value = cm
|
||||
|
||||
query = MagicMock()
|
||||
session.query.return_value = query
|
||||
query.where.return_value = query
|
||||
yield session
|
||||
|
||||
|
||||
@ -49,12 +46,12 @@ def mock_deps():
|
||||
|
||||
def _set_account_found(mock_db_session, email: str = "user@example.com"):
|
||||
account = SimpleNamespace(email=email)
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = account
|
||||
mock_db_session.scalar.return_value = account
|
||||
return account
|
||||
|
||||
|
||||
def _set_account_missing(mock_db_session):
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = None
|
||||
mock_db_session.scalar.return_value = None
|
||||
|
||||
|
||||
class TestDeleteAccountTask:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user