diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py
index 33751ab231..98abb3ef8d 100644
--- a/api/controllers/console/datasets/wraps.py
+++ b/api/controllers/console/datasets/wraps.py
@@ -27,7 +27,7 @@ def get_rag_pipeline(
 
             pipeline = (
                 db.session.query(Pipeline)
-                .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
+                .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
                 .first()
             )
 
diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
index cbc1907bf5..f05325d711 100644
--- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py
@@ -215,7 +215,7 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
 
         if not file.filename:
             raise FilenameNotExistsError
-        
+
         if not current_user:
             raise ValueError("Invalid user account")
 
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index 3b9bd224d9..ebb8b15163 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -195,9 +195,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
         # fetch workflow by workflow_id
         workflow = (
             db.session.query(Workflow)
-            .filter(
-                Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
-            )
+            .where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
             .first()
         )
 
@@ -272,7 +270,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
         if document_id and dataset_id:
             document = (
                 db.session.query(Document)
-                .filter(Document.id == document_id, Document.dataset_id == dataset_id)
+                .where(Document.id == document_id, Document.dataset_id == dataset_id)
                 .first()
             )
             if document:
diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py
index f4e3c656bc..0c50c2f980 100644
--- a/api/core/datasource/datasource_file_manager.py
+++ b/api/core/datasource/datasource_file_manager.py
@@ -152,10 +152,7 @@
 
         :return: the binary of the file, mime type
         """
-        upload_file: UploadFile | None = (
-            db.session.query(UploadFile).where(UploadFile.id == id)
-            .first()
-        )
+        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
         if not upload_file:
             return None
 
@@ -173,10 +170,7 @@
 
         :return: the binary of the file, mime type
         """
-        message_file: MessageFile | None = (
-            db.session.query(MessageFile).where(MessageFile.id == id)
-            .first()
-        )
+        message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
 
         # Check if message_file is not None
         if message_file is not None:
@@ -190,10 +184,7 @@
         else:
             tool_file_id = None
 
-        tool_file: ToolFile | None = (
-            db.session.query(ToolFile).where(ToolFile.id == tool_file_id)
-            .first()
-        )
+        tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
         if not tool_file:
             return None
 
@@ -211,10 +202,7 @@
 
         :return: the binary of the file, mime type
         """
-        upload_file: UploadFile | None = (
-            db.session.query(UploadFile).where(UploadFile.id == upload_file_id)
-            .first()
-        )
+        upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
         if not upload_file:
             return None, None
 
diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py
index 98680a5779..ac36d83ae3 100644
--- a/api/core/datasource/entities/common_entities.py
+++ b/api/core/datasource/entities/common_entities.py
@@ -1,4 +1,3 @@
-
 from pydantic import BaseModel, Field
 
 
diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py
index f6da4c7094..b9ca7414dc 100644
--- a/api/core/entities/knowledge_entities.py
+++ b/api/core/entities/knowledge_entities.py
@@ -1,4 +1,3 @@
-
 from pydantic import BaseModel
 
 
diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py
index b5eea0bf30..b9bf9d0d8c 100644
--- a/api/core/rag/extractor/entity/extract_setting.py
+++ b/api/core/rag/extractor/entity/extract_setting.py
@@ -1,4 +1,3 @@
-
 from pydantic import BaseModel, ConfigDict
 
 from models.dataset import Document
diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py
index 14e0315846..5bf1faee5d 100644
--- a/api/core/workflow/errors.py
+++ b/api/core/workflow/errors.py
@@ -13,4 +13,4 @@ class WorkflowNodeRunFailedError(Exception):
 
     @property
     def error(self) -> str:
-        return self._error
\ No newline at end of file
+        return self._error
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index d5ced1a246..4b6bad1aa3 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -160,7 +160,7 @@ class KnowledgeIndexNode(Node):
         document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         document.word_count = (
             db.session.query(func.sum(DocumentSegment.word_count))
-            .filter(
+            .where(
                 DocumentSegment.document_id == document.id,
                 DocumentSegment.dataset_id == dataset.id,
             )
@@ -168,7 +168,7 @@ class KnowledgeIndexNode(Node):
         )
         db.session.add(document)
         # update document segment status
-        db.session.query(DocumentSegment).filter(
+        db.session.query(DocumentSegment).where(
             DocumentSegment.document_id == document.id,
             DocumentSegment.dataset_id == dataset.id,
         ).update(
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 41505ab025..588168bd39 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -326,7 +326,7 @@ def _build_from_datasource_file(
 ) -> File:
     datasource_file = (
         db.session.query(UploadFile)
-        .filter(
+        .where(
             UploadFile.id == mapping.get("datasource_file_id"),
             UploadFile.tenant_id == tenant_id,
         )
diff --git a/api/models/dataset.py b/api/models/dataset.py
index d620d56006..2c4059f800 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -82,7 +82,7 @@ class Dataset(Base):
     def total_available_documents(self):
         return (
             db.session.query(func.count(Document.id))
-            .filter(
+            .where(
                 Document.dataset_id == self.id,
                 Document.indexing_status == "completed",
                 Document.enabled == True,
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 798233fd95..51507886ad 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -419,7 +419,7 @@ class DatasetService:
     def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
         dataset = (
             db.session.query(Dataset)
-            .filter(
+            .where(
                 Dataset.id != dataset_id,
                 Dataset.name == name,
                 Dataset.tenant_id == tenant_id,
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 8dceeee7ec..f05a892f93 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -690,7 +690,7 @@
         # Get all provider configurations of the current workspace
         datasource_providers: list[DatasourceProvider] = (
             db.session.query(DatasourceProvider)
-            .filter(
+            .where(
                 DatasourceProvider.tenant_id == tenant_id,
                 DatasourceProvider.provider == provider,
                 DatasourceProvider.plugin_id == plugin_id,
@@ -862,7 +862,7 @@
         # Get all provider configurations of the current workspace
         datasource_providers: list[DatasourceProvider] = (
             db.session.query(DatasourceProvider)
-            .filter(
+            .where(
                 DatasourceProvider.tenant_id == tenant_id,
                 DatasourceProvider.provider == provider,
                 DatasourceProvider.plugin_id == plugin_id,
diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
index 82a0a08ec6..ca871bcaa1 100644
--- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py
@@ -1,4 +1,3 @@
-
 import yaml
 from flask_login import current_user
 
@@ -36,7 +35,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         """
         pipeline_customized_templates = (
             db.session.query(PipelineCustomizedTemplate)
-            .filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
+            .where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
             .order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
             .all()
         )
diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
index a544767465..ec91f79606 100644
--- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
+++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py
@@ -1,4 +1,3 @@
-
 import yaml
 
 from extensions.ext_database import db
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 0232b9998f..cc7514aaba 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -138,7 +138,7 @@ class RagPipelineService:
         """
         customized_template: PipelineCustomizedTemplate | None = (
             db.session.query(PipelineCustomizedTemplate)
-            .filter(
+            .where(
                 PipelineCustomizedTemplate.id == template_id,
                 PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
             )
@@ -151,7 +151,7 @@
         if template_name:
             template = (
                 db.session.query(PipelineCustomizedTemplate)
-                .filter(
+                .where(
                     PipelineCustomizedTemplate.name == template_name,
                     PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
                     PipelineCustomizedTemplate.id != template_id,
@@ -174,7 +174,7 @@
         """
         customized_template: PipelineCustomizedTemplate | None = (
             db.session.query(PipelineCustomizedTemplate)
-            .filter(
+            .where(
                 PipelineCustomizedTemplate.id == template_id,
                 PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
             )
@@ -192,7 +192,7 @@
         # fetch draft workflow by rag pipeline
         workflow = (
             db.session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
@@ -214,7 +214,7 @@
         # fetch published workflow by workflow_id
         workflow = (
             db.session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.id == pipeline.workflow_id,
@@ -1015,7 +1015,7 @@
         """
         limit = int(args.get("limit", 20))
 
-        base_query = db.session.query(WorkflowRun).filter(
+        base_query = db.session.query(WorkflowRun).where(
             WorkflowRun.tenant_id == pipeline.tenant_id,
             WorkflowRun.app_id == pipeline.id,
             or_(
@@ -1025,7 +1025,7 @@
         )
 
         if args.get("last_id"):
-            last_workflow_run = base_query.filter(
+            last_workflow_run = base_query.where(
                 WorkflowRun.id == args.get("last_id"),
             ).first()
 
@@ -1033,7 +1033,7 @@
                 raise ValueError("Last workflow run not exists")
 
             workflow_runs = (
-                base_query.filter(
+                base_query.where(
                     WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
                 )
                 .order_by(WorkflowRun.created_at.desc())
@@ -1046,7 +1046,7 @@
         has_more = False
         if len(workflow_runs) == limit:
             current_page_first_workflow_run = workflow_runs[-1]
-            rest_count = base_query.filter(
+            rest_count = base_query.where(
                 WorkflowRun.created_at < current_page_first_workflow_run.created_at,
                 WorkflowRun.id != current_page_first_workflow_run.id,
             ).count()
@@ -1065,7 +1065,7 @@
         """
         workflow_run = (
             db.session.query(WorkflowRun)
-            .filter(
+            .where(
                 WorkflowRun.tenant_id == pipeline.tenant_id,
                 WorkflowRun.app_id == pipeline.id,
                 WorkflowRun.id == run_id,
@@ -1130,7 +1130,7 @@
         if template_name:
             template = (
                 db.session.query(PipelineCustomizedTemplate)
-                .filter(
+                .where(
                     PipelineCustomizedTemplate.name == template_name,
                     PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
                 )
@@ -1168,7 +1168,7 @@
     def is_workflow_exist(self, pipeline: Pipeline) -> bool:
         return (
             db.session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == Workflow.VERSION_DRAFT,
@@ -1362,10 +1362,10 @@
         """
         Get datasource plugins
         """
-        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
             raise ValueError("Dataset not found")
-        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found")
 
@@ -1446,10 +1446,10 @@
         """
         Get pipeline
         """
-        dataset: Dataset | None = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+        dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
             raise ValueError("Dataset not found")
-        pipeline: Pipeline | None = db.session.query(Pipeline).filter(Pipeline.id == dataset.pipeline_id).first()
+        pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found")
         return pipeline
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index e21d2d56bc..88f28e03ef 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -318,7 +318,7 @@
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
                     self._session.query(DatasetCollectionBinding)
-                    .filter(
+                    .where(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
                         DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
@@ -452,7 +452,7 @@
             if knowledge_configuration.indexing_technique == "high_quality":
                 dataset_collection_binding = (
                     self._session.query(DatasetCollectionBinding)
-                    .filter(
+                    .where(
                         DatasetCollectionBinding.provider_name
                         == knowledge_configuration.embedding_model_provider,
                         DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
@@ -599,7 +599,7 @@
             )
         workflow = (
             self._session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
@@ -673,7 +673,7 @@
 
         workflow = (
             self._session.query(Workflow)
-            .filter(
+            .where(
                 Workflow.tenant_id == pipeline.tenant_id,
                 Workflow.app_id == pipeline.id,
                 Workflow.version == "draft",
diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py
index df4a76d94f..713f149c38 100644
--- a/api/tasks/deal_dataset_index_update_task.py
+++ b/api/tasks/deal_dataset_index_update_task.py
@@ -33,7 +33,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
         if action == "upgrade":
             dataset_documents = (
                 db.session.query(DatasetDocument)
-                .filter(
+                .where(
                     DatasetDocument.dataset_id == dataset_id,
                     DatasetDocument.indexing_status == "completed",
                     DatasetDocument.enabled == True,
@@ -54,7 +54,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
                     # add from vector index
                     segments = (
                         db.session.query(DocumentSegment)
-                        .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+                        .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                         .order_by(DocumentSegment.position.asc())
                         .all()
                     )
@@ -88,7 +88,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
         elif action == "update":
            dataset_documents = (
                 db.session.query(DatasetDocument)
-                .filter(
+                .where(
                     DatasetDocument.dataset_id == dataset_id,
                     DatasetDocument.indexing_status == "completed",
                     DatasetDocument.enabled == True,
@@ -113,7 +113,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
                 try:
                     segments = (
                         db.session.query(DocumentSegment)
-                        .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+                        .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                         .order_by(DocumentSegment.position.asc())
                         .all()
                     )
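Note for reviewers: SQLAlchemy 1.4 introduced Query.where() as a synonym for Query.filter(), aligning the legacy Query API with the 2.0-style select() construct, so every .filter( -> .where( change above should be behavior-preserving. A minimal sketch of the equivalence, using a hypothetical, simplified stand-in for the Pipeline model (the real model in api/models has more columns):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Pipeline(Base):
    # Hypothetical stand-in for the real model in api/models.
    __tablename__ = "pipelines"

    id: Mapped[str] = mapped_column(String, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")  # throwaway in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy 1.x spelling, as the code read before this diff:
    legacy = (
        session.query(Pipeline)
        .filter(Pipeline.id == "p1", Pipeline.tenant_id == "t1")
        .first()
    )

    # Query.where() is a synonym for Query.filter() (SQLAlchemy >= 1.4),
    # so the rewritten queries emit identical SQL:
    synonym = (
        session.query(Pipeline)
        .where(Pipeline.id == "p1", Pipeline.tenant_id == "t1")
        .first()
    )

    # Fully 2.0-style equivalent, shown for comparison:
    modern = session.scalars(
        select(Pipeline).where(Pipeline.id == "p1", Pipeline.tenant_id == "t1")
    ).first()

    assert legacy is None and synonym is None and modern is None  # empty table

Query.where() requires SQLAlchemy >= 1.4; the select() form at the end is the forward-compatible spelling if the legacy Query API is eventually dropped.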