fix(api): fix format, replace .filter with .where

commit 5077f8b299
parent 24fc7d0d6b
Author: QuantumGhost
Date:   2025-09-17 22:55:13 +08:00
18 changed files with 41 additions and 60 deletions
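
Background for the rename: in SQLAlchemy 1.4 and later, Query.where() is a synonym for the legacy Query.filter(), so the replacements in this commit are behavior-preserving and move the code toward the 2.0-style API. A minimal, self-contained sketch of the equivalence (the Pipeline model below is a hypothetical stand-in for illustration, not the real Dify model):

    # Sketch assuming SQLAlchemy 2.x; Query.where() was added in 1.4 as a
    # synonym for the legacy Query.filter().
    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class Pipeline(Base):  # hypothetical stand-in model
        __tablename__ = "pipelines"
        id: Mapped[str] = mapped_column(String, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Pipeline(id="p1"))
        session.commit()
        # Both spellings emit the same SQL and return the same row:
        legacy = session.query(Pipeline).filter(Pipeline.id == "p1").first()
        modern = session.query(Pipeline).where(Pipeline.id == "p1").first()
        assert legacy is modern  # same instance via the session identity map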

@@ -27,7 +27,7 @@ def get_rag_pipeline(
     pipeline = (
         db.session.query(Pipeline)
-        .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
+        .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
         .first()
     )

@@ -215,7 +215,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
         if not file.filename:
             raise FilenameNotExistsError
         if not current_user:
             raise ValueError("Invalid user account")

@@ -195,9 +195,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
         # fetch workflow by workflow_id
         workflow = (
             db.session.query(Workflow)
-            .filter(
-                Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
-            )
+            .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
             .first()
         )
@@ -272,7 +270,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
         if document_id and dataset_id:
             document = (
                 db.session.query(Document)
-                .filter(Document.id == document_id, Document.dataset_id == dataset_id)
+                .where(Document.id == document_id, Document.dataset_id == dataset_id)
                 .first()
             )
             if document:

@@ -152,10 +152,7 @@ class DatasourceFileManager:
         :return: the binary of the file, mime type
         """
-        upload_file: UploadFile | None = (
-            db.session.query(UploadFile).where(UploadFile.id == id)
-            .first()
-        )
+        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
         if not upload_file:
             return None
@@ -173,10 +170,7 @@ class DatasourceFileManager:
         :return: the binary of the file, mime type
         """
-        message_file: MessageFile | None = (
-            db.session.query(MessageFile).where(MessageFile.id == id)
-            .first()
-        )
+        message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
         # Check if message_file is not None
         if message_file is not None:
@@ -190,10 +184,7 @@ class DatasourceFileManager:
         else:
             tool_file_id = None
-        tool_file: ToolFile | None = (
-            db.session.query(ToolFile).where(ToolFile.id == tool_file_id)
-            .first()
-        )
+        tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
         if not tool_file:
             return None
@@ -211,10 +202,7 @@ class DatasourceFileManager:
         :return: the binary of the file, mime type
         """
-        upload_file: UploadFile | None = (
-            db.session.query(UploadFile).where(UploadFile.id == upload_file_id)
-            .first()
-        )
+        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
         if not upload_file:
             return None, None

@@ -1,4 +1,3 @@
 from pydantic import BaseModel, Field

@@ -1,4 +1,3 @@
 from pydantic import BaseModel

@@ -1,4 +1,3 @@
 from pydantic import BaseModel, ConfigDict
 from models.dataset import Document

@@ -13,4 +13,4 @@ class WorkflowNodeRunFailedError(Exception):
     @property
     def error(self) -> str:
-        return self._error
+        return self._error

@@ -160,7 +160,7 @@ class KnowledgeIndexNode(Node):
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         document.word_count = (
             db.session.query(func.sum(DocumentSegment.word_count))
-            .filter(
+            .where(
                 DocumentSegment.document_id == document.id,
                 DocumentSegment.dataset_id == dataset.id,
             )
@@ -168,7 +168,7 @@ class KnowledgeIndexNode(Node):
         )
         db.session.add(document)
         # update document segment status
-        db.session.query(DocumentSegment).filter(
+        db.session.query(DocumentSegment).where(
             DocumentSegment.document_id == document.id,
             DocumentSegment.dataset_id == dataset.id,
         ).update(

@@ -326,7 +326,7 @@ def _build_from_datasource_file(
 ) -> File:
     datasource_file = (
         db.session.query(UploadFile)
-        .filter(
+        .where(
             UploadFile.id == mapping.get("datasource_file_id"),
             UploadFile.tenant_id == tenant_id,
         )

@@ -82,7 +82,7 @@ class Dataset(Base):
     def total_available_documents(self):
         return (
             db.session.query(func.count(Document.id))
-            .filter(
+            .where(
                 Document.dataset_id == self.id,
                 Document.indexing_status == "completed",
                 Document.enabled == True,

@@ -419,7 +419,7 @@ class DatasetService:
     def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
         dataset = (
             db.session.query(Dataset)
-            .filter(
+            .where(
                 Dataset.id != dataset_id,
                 Dataset.name == name,
                 Dataset.tenant_id == tenant_id,

@@ -690,7 +690,7 @@ class DatasourceProviderService:
         # Get all provider configurations of the current workspace
         datasource_providers: list[DatasourceProvider] = (
             db.session.query(DatasourceProvider)
-            .filter(
+            .where(
                 DatasourceProvider.tenant_id == tenant_id,
                 DatasourceProvider.provider == provider,
                 DatasourceProvider.plugin_id == plugin_id,
@@ -862,7 +862,7 @@ class DatasourceProviderService:
         # Get all provider configurations of the current workspace
         datasource_providers: list[DatasourceProvider] = (
             db.session.query(DatasourceProvider)
-            .filter(
+            .where(
                 DatasourceProvider.tenant_id == tenant_id,
                 DatasourceProvider.provider == provider,
                 DatasourceProvider.plugin_id == plugin_id,

@@ -1,4 +1,3 @@
 import yaml
 from flask_login import current_user
@@ -36,7 +35,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         """
         pipeline_customized_templates = (
             db.session.query(PipelineCustomizedTemplate)
-            .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
+            .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
             .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
             .all()
         )

@@ -1,4 +1,3 @@
 import yaml
 from extensions.ext_database import db

@@ -138,7 +138,7 @@ class RagPipelineService:
         """
         customized_template: PipelineCustomizedTemplate | None = (
             db.session.query(PipelineCustomizedTemplate)
-            .filter(
+            .where(
                 PipelineCustomizedTemplate.id == template_id,
                 PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
             )
@@ -151,7 +151,7 @@ class RagPipelineService:
         if template_name:
             template = (
                 db.session.query(PipelineCustomizedTemplate)
-                .filter(
+                .where(
                     PipelineCustomizedTemplate.name == template_name,
                     PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                     PipelineCustomizedTemplate.id != template_id,
@@ -174,7 +174,7 @@ class RagPipelineService:
         """
         customized_template: PipelineCustomizedTemplate | None = (
             db.session.query(PipelineCustomizedTemplate)
-            .filter(
+            .where(
                 PipelineCustomizedTemplate.id == template_id,
                 PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
             )
@@ -192,7 +192,7 @@ class RagPipelineService:
         # fetch draft workflow by rag pipeline
         workflow = (
             db.session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
@@ -214,7 +214,7 @@ class RagPipelineService:
         # fetch published workflow by workflow_id
         workflow = (
             db.session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.id == pipeline.workflow_id,
@@ -1015,7 +1015,7 @@ class RagPipelineService:
         """
         limit = int(args.get("limit", 20))
-        base_query = db.session.query(WorkflowRun).filter(
+        base_query = db.session.query(WorkflowRun).where(
             WorkflowRun.tenant_id == pipeline.tenant_id,
             WorkflowRun.app_id == pipeline.id,
             or_(
@@ -1025,7 +1025,7 @@
         )
         if args.get("last_id"):
-            last_workflow_run = base_query.filter(
+            last_workflow_run = base_query.where(
                 WorkflowRun.id == args.get("last_id"),
             ).first()
@@ -1033,7 +1033,7 @@
                 raise ValueError("Last workflow run not exists")
             workflow_runs = (
-                base_query.filter(
+                base_query.where(
                     WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                 )
                 .order_by(WorkflowRun.created_at.desc())
@@ -1046,7 +1046,7 @@
         has_more = False
         if len(workflow_runs) == limit:
             current_page_first_workflow_run = workflow_runs[-1]
-            rest_count = base_query.filter(
+            rest_count = base_query.where(
                 WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                 WorkflowRun.id != current_page_first_workflow_run.id,
             ).count()
@@ -1065,7 +1065,7 @@
         """
         workflow_run = (
             db.session.query(WorkflowRun)
-            .filter(
+            .where(
                 WorkflowRun.tenant_id == pipeline.tenant_id,
                 WorkflowRun.app_id == pipeline.id,
                 WorkflowRun.id == run_id,
@@ -1130,7 +1130,7 @@ class RagPipelineService:
         if template_name:
             template = (
                 db.session.query(PipelineCustomizedTemplate)
-                .filter(
+                .where(
                     PipelineCustomizedTemplate.name == template_name,
                     PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                 )
@@ -1168,7 +1168,7 @@ class RagPipelineService:
     def is_workflow_exist(self, pipeline: Pipeline) -> bool:
         return (
             db.session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == Workflow.VERSION_DRAFT,
@@ -1362,10 +1362,10 @@ class RagPipelineService:
         """
         Get datasource plugins
         """
-        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
             raise ValueError("Dataset not found")
-        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found")
@@ -1446,10 +1446,10 @@ class RagPipelineService:
         """
         Get pipeline
         """
-        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
             raise ValueError("Dataset not found")
-        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found")
         return pipeline

@@ -318,7 +318,7 @@ class RagPipelineDslService:
         if knowledge_configuration.indexing_technique == "high_quality":
             dataset_collection_binding = (
                 self._session.query(DatasetCollectionBinding)
-                .filter(
+                .where(
                     DatasetCollectionBinding.provider_name
                     == knowledge_configuration.embedding_model_provider,
                     DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
@@ -452,7 +452,7 @@ class RagPipelineDslService:
         if knowledge_configuration.indexing_technique == "high_quality":
             dataset_collection_binding = (
                 self._session.query(DatasetCollectionBinding)
-                .filter(
+                .where(
                     DatasetCollectionBinding.provider_name
                     == knowledge_configuration.embedding_model_provider,
                     DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
@@ -599,7 +599,7 @@ class RagPipelineDslService:
         )
         workflow = (
             self._session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
@@ -673,7 +673,7 @@
         workflow = (
             self._session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",

@@ -33,7 +33,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
     if action == "upgrade":
         dataset_documents = (
             db.session.query(DatasetDocument)
-            .filter(
+            .where(
                 DatasetDocument.dataset_id == dataset_id,
                 DatasetDocument.indexing_status == "completed",
                 DatasetDocument.enabled == True,
@@ -54,7 +54,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
                 # add from vector index
                 segments = (
                     db.session.query(DocumentSegment)
-                    .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+                    .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                     .order_by(DocumentSegment.position.asc())
                     .all()
                 )
@@ -88,7 +88,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
     elif action == "update":
        dataset_documents = (
             db.session.query(DatasetDocument)
-            .filter(
+            .where(
                 DatasetDocument.dataset_id == dataset_id,
                 DatasetDocument.indexing_status == "completed",
                 DatasetDocument.enabled == True,
@@ -113,7 +113,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
             try:
                 segments = (
                     db.session.query(DocumentSegment)
-                    .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+                    .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                     .order_by(DocumentSegment.position.asc())
                     .all()
                 )