diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index e6360706ee..f04b0e04c3 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -104,7 +104,7 @@ class CustomizedPipelineTemplateApi(Resource): def post(self, template_id: str): with Session(db.engine) as session: template = ( - session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() ) if not template: raise ValueError("Customized pipeline template not found.") diff --git a/api/controllers/console/datasets/wraps.py b/api/controllers/console/datasets/wraps.py index 26783d8cf8..33751ab231 100644 --- a/api/controllers/console/datasets/wraps.py +++ b/api/controllers/console/datasets/wraps.py @@ -1,6 +1,5 @@ from collections.abc import Callable from functools import wraps -from typing import Optional from controllers.console.datasets.error import PipelineNotFoundError from extensions.ext_database import db @@ -10,7 +9,7 @@ from models.dataset import Pipeline def get_rag_pipeline( - view: Optional[Callable] = None, + view: Callable | None = None, ): def decorator(view_func): @wraps(view_func) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 3dbc8706d3..e836a46f8f 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,6 +1,6 @@ from collections.abc import Sequence from enum import StrEnum, auto -from typing import Any, Literal, Optional +from typing import Any, Literal from pydantic import BaseModel, Field, field_validator @@ -114,9 +114,9 @@ class VariableEntity(BaseModel): hide: bool = False max_length: int | None = None options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list) - allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list) - allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list) + allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) + allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) @field_validator("description", mode="before") @classmethod @@ -134,8 +134,8 @@ class RagPipelineVariableEntity(VariableEntity): Rag Pipeline Variable Entity. """ - tooltips: Optional[str] = None - placeholder: Optional[str] = None + tooltips: str | None = None + placeholder: str | None = None belong_to_node_id: str @@ -298,7 +298,7 @@ class AppConfig(BaseModel): tenant_id: str app_id: str app_mode: AppMode - additional_features: Optional[AppAdditionalFeatures] = None + additional_features: AppAdditionalFeatures | None = None variables: list[VariableEntity] = [] sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index b1e98ed3ea..d441f273d8 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -7,7 +7,7 @@ import threading import time import uuid from collections.abc import Generator, Mapping -from typing import Any, Literal, Optional, Union, cast, overload +from typing import Any, Literal, Union, cast, overload from flask import Flask, current_app from pydantic import ValidationError @@ -69,7 +69,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[True], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, is_retry: bool = False, ) -> Generator[Mapping | str, None, None]: ... @@ -84,7 +84,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: Literal[False], call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, is_retry: bool = False, ) -> Mapping[str, Any]: ... @@ -99,7 +99,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool, call_depth: int, - workflow_thread_pool_id: Optional[str], + workflow_thread_pool_id: str | None, is_retry: bool = False, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ... @@ -113,7 +113,7 @@ class PipelineGenerator(BaseAppGenerator): invoke_from: InvokeFrom, streaming: bool = True, call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, is_retry: bool = False, ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]: # Add null check for dataset @@ -314,7 +314,7 @@ class PipelineGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -331,7 +331,7 @@ class PipelineGenerator(BaseAppGenerator): """ with preserve_flask_contexts(flask_app, context_vars=context): # init queue manager - workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first() + workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first() if not workflow: raise ValueError(f"Workflow not found: {workflow_id}") queue_manager = PipelineQueueManager( @@ -568,7 +568,7 @@ class PipelineGenerator(BaseAppGenerator): queue_manager: AppQueueManager, context: contextvars.Context, variable_loader: VariableLoader, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> None: """ Generate worker in a new thread. @@ -801,11 +801,11 @@ class PipelineGenerator(BaseAppGenerator): self, datasource_runtime: OnlineDriveDatasourcePlugin, prefix: str, - bucket: Optional[str], + bucket: str | None, user_id: str, all_files: list, datasource_info: Mapping[str, Any], - next_page_parameters: Optional[dict] = None, + next_page_parameters: dict | None = None, ): """ Get files in a folder. diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index f2f01d1ee7..3b9bd224d9 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -1,6 +1,6 @@ import logging import time -from typing import Optional, cast +from typing import cast from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig @@ -40,7 +40,7 @@ class PipelineRunner(WorkflowBasedAppRunner): variable_loader: VariableLoader, workflow: Workflow, system_user_id: str, - workflow_thread_pool_id: Optional[str] = None, + workflow_thread_pool_id: str | None = None, ) -> None: """ :param application_generate_entity: application generate entity @@ -69,13 +69,13 @@ class PipelineRunner(WorkflowBasedAppRunner): user_id = None if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() + end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first() if end_user: user_id = end_user.session_id else: user_id = self.application_generate_entity.user_id - pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first() if not pipeline: raise ValueError("Pipeline not found") @@ -188,7 +188,7 @@ class PipelineRunner(WorkflowBasedAppRunner): ) self._handle_event(workflow_entry, event) - def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]: + def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: """ Get workflow """ @@ -205,7 +205,7 @@ class PipelineRunner(WorkflowBasedAppRunner): return workflow def _init_rag_pipeline_graph( - self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None + self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None ) -> Graph: """ Init pipeline graph diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 14dd78489a..a5ed0f8fa3 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -242,7 +242,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): node_id: str inputs: dict - single_loop_run: Optional[SingleLoopRunEntity] = None + single_loop_run: SingleLoopRunEntity | None = None class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): @@ -256,9 +256,9 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity): datasource_info: Mapping[str, Any] dataset_id: str batch: str - document_id: Optional[str] = None - original_document_id: Optional[str] = None - start_node_id: Optional[str] = None + document_id: str | None = None + original_document_id: str | None = None + start_node_id: str | None = None # Import TraceQueueManager at runtime to resolve forward references diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index fe8c916d3a..31dc1eea89 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -1,6 +1,6 @@ from collections.abc import Mapping, Sequence from enum import StrEnum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, ConfigDict, Field @@ -252,8 +252,8 @@ class NodeStartStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None inputs_truncated: bool = False created_at: int extras: dict[str, object] = Field(default_factory=dict) @@ -310,12 +310,12 @@ class NodeFinishStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None inputs_truncated: bool = False - process_data: Optional[Mapping[str, Any]] = None + process_data: Mapping[str, Any] | None = None process_data_truncated: bool = False - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None outputs_truncated: bool = True status: str error: str | None = None @@ -382,12 +382,12 @@ class NodeRetryStreamResponse(StreamResponse): node_type: str title: str index: int - predecessor_node_id: Optional[str] = None - inputs: Optional[Mapping[str, Any]] = None + predecessor_node_id: str | None = None + inputs: Mapping[str, Any] | None = None inputs_truncated: bool = False - process_data: Optional[Mapping[str, Any]] = None + process_data: Mapping[str, Any] | None = None process_data_truncated: bool = False - outputs: Optional[Mapping[str, Any]] = None + outputs: Mapping[str, Any] | None = None outputs_truncated: bool = False status: str error: str | None = None @@ -503,11 +503,11 @@ class IterationNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus error: str | None = None @@ -541,8 +541,8 @@ class LoopNodeStartStreamResponse(StreamResponse): metadata: Mapping = {} inputs: Mapping = {} inputs_truncated: bool = False - parallel_id: Optional[str] = None - parallel_start_node_id: Optional[str] = None + parallel_id: str | None = None + parallel_start_node_id: str | None = None event: StreamEvent = StreamEvent.LOOP_STARTED workflow_run_id: str @@ -590,11 +590,11 @@ class LoopNodeCompletedStreamResponse(StreamResponse): node_id: str node_type: str title: str - outputs: Optional[Mapping] = None + outputs: Mapping | None = None outputs_truncated: bool = False created_at: int - extras: Optional[dict] = None - inputs: Optional[Mapping] = None + extras: dict | None = None + inputs: Mapping | None = None inputs_truncated: bool = False status: WorkflowNodeExecutionStatus error: str | None = None diff --git a/api/core/datasource/__base/datasource_runtime.py b/api/core/datasource/__base/datasource_runtime.py index 7d24bd7c6d..b7f280208a 100644 --- a/api/core/datasource/__base/datasource_runtime.py +++ b/api/core/datasource/__base/datasource_runtime.py @@ -17,9 +17,9 @@ class DatasourceRuntime(BaseModel): """ tenant_id: str - datasource_id: Optional[str] = None + datasource_id: str | None = None invoke_from: Optional["InvokeFrom"] = None - datasource_invoke_from: Optional[DatasourceInvokeFrom] = None + datasource_invoke_from: DatasourceInvokeFrom | None = None credentials: dict[str, Any] = Field(default_factory=dict) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index beb5ce7b04..f4e3c656bc 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -6,7 +6,7 @@ import os import time from datetime import datetime from mimetypes import guess_extension, guess_type -from typing import Optional, Union +from typing import Union from uuid import uuid4 import httpx @@ -62,10 +62,10 @@ class DatasourceFileManager: *, user_id: str, tenant_id: str, - conversation_id: Optional[str], + conversation_id: str | None, file_binary: bytes, mimetype: str, - filename: Optional[str] = None, + filename: str | None = None, ) -> UploadFile: extension = guess_extension(mimetype) or ".bin" unique_name = uuid4().hex @@ -106,7 +106,7 @@ class DatasourceFileManager: user_id: str, tenant_id: str, file_url: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> ToolFile: # try to download image try: @@ -153,10 +153,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ upload_file: UploadFile | None = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == id, - ) + db.session.query(UploadFile).where(UploadFile.id == id) .first() ) @@ -177,10 +174,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ message_file: MessageFile | None = ( - db.session.query(MessageFile) - .filter( - MessageFile.id == id, - ) + db.session.query(MessageFile).where(MessageFile.id == id) .first() ) @@ -197,10 +191,7 @@ class DatasourceFileManager: tool_file_id = None tool_file: ToolFile | None = ( - db.session.query(ToolFile) - .filter( - ToolFile.id == tool_file_id, - ) + db.session.query(ToolFile).where(ToolFile.id == tool_file_id) .first() ) @@ -221,10 +212,7 @@ class DatasourceFileManager: :return: the binary of the file, mime type """ upload_file: UploadFile | None = ( - db.session.query(UploadFile) - .filter( - UploadFile.id == upload_file_id, - ) + db.session.query(UploadFile).where(UploadFile.id == upload_file_id) .first() ) diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 81771719ea..af8ce4ed9b 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -12,9 +12,9 @@ class DatasourceApiEntity(BaseModel): name: str # identifier label: I18nObject # label description: I18nObject - parameters: Optional[list[DatasourceParameter]] = None + parameters: list[DatasourceParameter] | None = None labels: list[str] = Field(default_factory=list) - output_schema: Optional[dict] = None + output_schema: dict | None = None ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]] @@ -28,12 +28,12 @@ class DatasourceProviderApiEntity(BaseModel): icon: str | dict label: I18nObject # label type: str - masked_credentials: Optional[dict] = None - original_credentials: Optional[dict] = None + masked_credentials: dict | None = None + original_credentials: dict | None = None is_team_authorization: bool = False allow_delete: bool = True - plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource") - plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource") + plugin_id: str | None = Field(default="", description="The plugin id of the datasource") + plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource") datasources: list[DatasourceApiEntity] = Field(default_factory=list) labels: list[str] = Field(default_factory=list) diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py index 924e6fc0cf..98680a5779 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, Field @@ -9,9 +8,9 @@ class I18nObject(BaseModel): """ en_US: str - zh_Hans: Optional[str] = Field(default=None) - pt_BR: Optional[str] = Field(default=None) - ja_JP: Optional[str] = Field(default=None) + zh_Hans: str | None = Field(default=None) + pt_BR: str | None = Field(default=None) + ja_JP: str | None = Field(default=None) def __init__(self, **data): super().__init__(**data) diff --git a/api/core/datasource/entities/datasource_entities.py b/api/core/datasource/entities/datasource_entities.py index 0c2011f841..ac4f51ac75 100644 --- a/api/core/datasource/entities/datasource_entities.py +++ b/api/core/datasource/entities/datasource_entities.py @@ -1,6 +1,6 @@ import enum from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, ValidationInfo, field_validator from yarl import URL @@ -80,7 +80,7 @@ class DatasourceParameter(PluginParameter): name: str, typ: DatasourceParameterType, required: bool, - options: Optional[list[str]] = None, + options: list[str] | None = None, ) -> "DatasourceParameter": """ get a simple datasource parameter @@ -120,14 +120,14 @@ class DatasourceIdentity(BaseModel): name: str = Field(..., description="The name of the datasource") label: I18nObject = Field(..., description="The label of the datasource") provider: str = Field(..., description="The provider of the datasource") - icon: Optional[str] = None + icon: str | None = None class DatasourceEntity(BaseModel): identity: DatasourceIdentity parameters: list[DatasourceParameter] = Field(default_factory=list) description: I18nObject = Field(..., description="The label of the datasource") - output_schema: Optional[dict] = None + output_schema: dict | None = None @field_validator("parameters", mode="before") @classmethod @@ -141,7 +141,7 @@ class DatasourceProviderIdentity(BaseModel): description: I18nObject = Field(..., description="The description of the tool") icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") - tags: Optional[list[ToolLabelEnum]] = Field( + tags: list[ToolLabelEnum] | None = Field( default=[], description="The tags of the tool", ) @@ -169,7 +169,7 @@ class DatasourceProviderEntity(BaseModel): identity: DatasourceProviderIdentity credentials_schema: list[ProviderConfig] = Field(default_factory=list) - oauth_schema: Optional[OAuthSchema] = None + oauth_schema: OAuthSchema | None = None provider_type: DatasourceProviderType @@ -183,8 +183,8 @@ class DatasourceInvokeMeta(BaseModel): """ time_cost: float = Field(..., description="The time cost of the tool invoke") - error: Optional[str] = None - tool_config: Optional[dict] = None + error: str | None = None + tool_config: dict | None = None @classmethod def empty(cls) -> "DatasourceInvokeMeta": @@ -233,10 +233,10 @@ class OnlineDocumentPage(BaseModel): page_id: str = Field(..., description="The page id") page_name: str = Field(..., description="The page title") - page_icon: Optional[dict] = Field(None, description="The page icon") + page_icon: dict | None = Field(None, description="The page icon") type: str = Field(..., description="The type of the page") last_edited_time: str = Field(..., description="The last edited time") - parent_id: Optional[str] = Field(None, description="The parent page id") + parent_id: str | None = Field(None, description="The parent page id") class OnlineDocumentInfo(BaseModel): @@ -244,9 +244,9 @@ class OnlineDocumentInfo(BaseModel): Online document info """ - workspace_id: Optional[str] = Field(None, description="The workspace id") - workspace_name: Optional[str] = Field(None, description="The workspace name") - workspace_icon: Optional[str] = Field(None, description="The workspace icon") + workspace_id: str | None = Field(None, description="The workspace id") + workspace_name: str | None = Field(None, description="The workspace name") + workspace_icon: str | None = Field(None, description="The workspace icon") total: int = Field(..., description="The total number of documents") pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document") @@ -307,10 +307,10 @@ class WebSiteInfo(BaseModel): Website info """ - status: Optional[str] = Field(..., description="crawl job status") - web_info_list: Optional[list[WebSiteInfoDetail]] = [] - total: Optional[int] = Field(default=0, description="The total number of websites") - completed: Optional[int] = Field(default=0, description="The number of completed websites") + status: str | None = Field(..., description="crawl job status") + web_info_list: list[WebSiteInfoDetail] | None = [] + total: int | None = Field(default=0, description="The total number of websites") + completed: int | None = Field(default=0, description="The number of completed websites") class WebsiteCrawlMessage(BaseModel): @@ -346,10 +346,10 @@ class OnlineDriveFileBucket(BaseModel): Online drive file bucket """ - bucket: Optional[str] = Field(None, description="The file bucket") + bucket: str | None = Field(None, description="The file bucket") files: list[OnlineDriveFile] = Field(..., description="The file list") is_truncated: bool = Field(False, description="Whether the result is truncated") - next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesRequest(BaseModel): @@ -357,10 +357,10 @@ class OnlineDriveBrowseFilesRequest(BaseModel): Get online drive file list request """ - bucket: Optional[str] = Field(None, description="The file bucket") + bucket: str | None = Field(None, description="The file bucket") prefix: str = Field(..., description="The parent folder ID") max_keys: int = Field(20, description="Page size for pagination") - next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page") + next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page") class OnlineDriveBrowseFilesResponse(BaseModel): @@ -377,4 +377,4 @@ class OnlineDriveDownloadFileRequest(BaseModel): """ id: str = Field(..., description="The id of the file") - bucket: Optional[str] = Field(None, description="The name of the bucket") + bucket: str | None = Field(None, description="The name of the bucket") diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index bb6ac6c1fc..5aa25b573f 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -1,7 +1,6 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type -from typing import Optional from core.datasource.entities.datasource_entities import DatasourceMessage from core.file import File, FileTransferMethod, FileType @@ -17,7 +16,7 @@ class DatasourceFileMessageTransformer: messages: Generator[DatasourceMessage, None, None], user_id: str, tenant_id: str, - conversation_id: Optional[str] = None, + conversation_id: str | None = None, ) -> Generator[DatasourceMessage, None, None]: """ Transform datasource message and handle file download @@ -121,5 +120,5 @@ class DatasourceFileMessageTransformer: yield message @classmethod - def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str: + def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str: return f"/files/datasources/{datasource_file_id}{extension or '.bin'}" diff --git a/api/core/datasource/utils/parser.py b/api/core/datasource/utils/parser.py index 57ee15d7f2..db1766a059 100644 --- a/api/core/datasource/utils/parser.py +++ b/api/core/datasource/utils/parser.py @@ -3,7 +3,6 @@ import uuid from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError -from typing import Optional from flask import request from requests import get @@ -169,9 +168,9 @@ class ApiBasedToolSchemaParser: return bundles @staticmethod - def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]: + def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None: parameter = parameter or {} - typ: Optional[str] = None + typ: str | None = None if parameter.get("format") == "binary": return ToolParameter.ToolParameterType.FILE diff --git a/api/core/entities/knowledge_entities.py b/api/core/entities/knowledge_entities.py index 33e1f64579..f6da4c7094 100644 --- a/api/core/entities/knowledge_entities.py +++ b/api/core/entities/knowledge_entities.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel @@ -16,7 +15,7 @@ class QAPreviewDetail(BaseModel): class IndexingEstimate(BaseModel): total_segments: int preview: list[PreviewDetail] - qa_preview: Optional[list[QAPreviewDetail]] = None + qa_preview: list[QAPreviewDetail] | None = None class PipelineDataset(BaseModel): @@ -30,10 +29,10 @@ class PipelineDocument(BaseModel): id: str position: int data_source_type: str - data_source_info: Optional[dict] = None + data_source_info: dict | None = None name: str indexing_status: str - error: Optional[str] = None + error: str | None = None enabled: bool diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 3063cd39ae..57012bf495 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -1,7 +1,7 @@ import datetime from collections.abc import Mapping from enum import StrEnum, auto -from typing import Any, Optional +from typing import Any from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -67,10 +67,10 @@ class PluginCategory(StrEnum): class PluginDeclaration(BaseModel): class Plugins(BaseModel): - tools: Optional[list[str]] = Field(default_factory=list[str]) - models: Optional[list[str]] = Field(default_factory=list[str]) - endpoints: Optional[list[str]] = Field(default_factory=list[str]) - datasources: Optional[list[str]] = Field(default_factory=list[str]) + tools: list[str] | None = Field(default_factory=list[str]) + models: list[str] | None = Field(default_factory=list[str]) + endpoints: list[str] | None = Field(default_factory=list[str]) + datasources: list[str] | None = Field(default_factory=list[str]) class Meta(BaseModel): minimum_dify_version: str | None = Field(default=None) @@ -101,11 +101,11 @@ class PluginDeclaration(BaseModel): tags: list[str] = Field(default_factory=list) repo: str | None = Field(default=None) verified: bool = Field(default=False) - tool: Optional[ToolProviderEntity] = None - model: Optional[ProviderEntity] = None - endpoint: Optional[EndpointProviderDeclaration] = None - agent_strategy: Optional[AgentStrategyProviderEntity] = None - datasource: Optional[DatasourceProviderEntity] = None + tool: ToolProviderEntity | None = None + model: ProviderEntity | None = None + endpoint: EndpointProviderDeclaration | None = None + agent_strategy: AgentStrategyProviderEntity | None = None + datasource: DatasourceProviderEntity | None = None meta: Meta @field_validator("version") diff --git a/api/core/rag/entities/event.py b/api/core/rag/entities/event.py index a36e32fc9c..24db5d77be 100644 --- a/api/core/rag/entities/event.py +++ b/api/core/rag/entities/event.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from enum import Enum -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field @@ -27,12 +27,12 @@ class DatasourceErrorEvent(BaseDatasourceEvent): class DatasourceCompletedEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.COMPLETED.value data: Mapping[str, Any] | list = Field(..., description="result") - total: Optional[int] = Field(default=0, description="total") - completed: Optional[int] = Field(default=0, description="completed") - time_consuming: Optional[float] = Field(default=0.0, description="time consuming") + total: int | None = Field(default=0, description="total") + completed: int | None = Field(default=0, description="completed") + time_consuming: float | None = Field(default=0.0, description="time consuming") class DatasourceProcessingEvent(BaseDatasourceEvent): event: str = DatasourceStreamEvent.PROCESSING.value - total: Optional[int] = Field(..., description="total") - completed: Optional[int] = Field(..., description="completed") + total: int | None = Field(..., description="total") + completed: int | None = Field(..., description="completed") diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index c0e79b02c4..b5eea0bf30 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -1,4 +1,3 @@ -from typing import Optional from pydantic import BaseModel, ConfigDict @@ -11,7 +10,7 @@ class NotionInfo(BaseModel): Notion import info. """ - credential_id: Optional[str] = None + credential_id: str | None = None notion_workspace_id: str notion_obj_id: str notion_page_type: str diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index c1563840f0..bddf41af43 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -1,7 +1,7 @@ import json import logging import operator -from typing import Any, Optional, cast +from typing import Any, cast import requests @@ -35,9 +35,9 @@ class NotionExtractor(BaseExtractor): notion_obj_id: str, notion_page_type: str, tenant_id: str, - document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None, - credential_id: Optional[str] = None, + document_model: DocumentModel | None = None, + notion_access_token: str | None = None, + credential_id: str | None = None, ): self._notion_access_token = None self._document_model = document_model @@ -369,7 +369,7 @@ class NotionExtractor(BaseExtractor): return cast(str, data["last_edited_time"]) @classmethod - def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str: + def _get_access_token(cls, tenant_id: str, credential_id: str | None) -> str: # get credential from tenant_id and credential_id if not credential_id: raise Exception(f"No credential id found for tenant {tenant_id}") diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 219aec5a03..5226a1071f 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ import json import logging from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor -from typing import Any, Optional, TypeVar, Union +from typing import Any, TypeVar, Union import psycopg2.errors from sqlalchemy import UnaryExpression, asc, desc, select @@ -530,7 +530,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) def get_by_workflow_run( self, workflow_run_id: str, - order_config: Optional[OrderConfig] = None, + order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ diff --git a/api/core/schemas/registry.py b/api/core/schemas/registry.py index 339784267c..867e4803bc 100644 --- a/api/core/schemas/registry.py +++ b/api/core/schemas/registry.py @@ -85,7 +85,7 @@ class SchemaRegistry: except (OSError, json.JSONDecodeError) as e: print(f"Warning: failed to load schema {version}/{schema_name}: {e}") - def get_schema(self, uri: str) -> Optional[Any]: + def get_schema(self, uri: str) -> Any | None: """Retrieves a schema by URI with version support""" version, schema_name = self._parse_uri(uri) if not version or not schema_name: diff --git a/api/core/schemas/resolver.py b/api/core/schemas/resolver.py index 1c5dabd79b..1b57f5bb94 100644 --- a/api/core/schemas/resolver.py +++ b/api/core/schemas/resolver.py @@ -3,7 +3,7 @@ import re import threading from collections import deque from dataclasses import dataclass -from typing import Any, Optional, Union +from typing import Any, Union from core.schemas.registry import SchemaRegistry @@ -53,8 +53,8 @@ class QueueItem: """Represents an item in the BFS queue""" current: Any - parent: Optional[Any] - key: Optional[Union[str, int]] + parent: Any | None + key: Union[str, int] | None depth: int ref_path: set[str] @@ -65,7 +65,7 @@ class SchemaResolver: _cache: dict[str, SchemaDict] = {} _cache_lock = threading.Lock() - def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10): + def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10): """ Initialize the schema resolver @@ -202,7 +202,7 @@ class SchemaResolver: ) ) - def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]: + def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None: """Get resolved schema from cache or registry""" # Check cache first with self._cache_lock: @@ -223,7 +223,7 @@ class SchemaResolver: def resolve_dify_schema_refs( - schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30 + schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30 ) -> SchemaType: """ Resolve $ref references in Dify schema to actual schema content diff --git a/api/core/schemas/schema_manager.py b/api/core/schemas/schema_manager.py index 3c9314db66..833ab609c7 100644 --- a/api/core/schemas/schema_manager.py +++ b/api/core/schemas/schema_manager.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from core.schemas.registry import SchemaRegistry @@ -7,7 +7,7 @@ from core.schemas.registry import SchemaRegistry class SchemaManager: """Schema manager provides high-level schema operations""" - def __init__(self, registry: Optional[SchemaRegistry] = None): + def __init__(self, registry: SchemaRegistry | None = None): self.registry = registry or SchemaRegistry.default_registry() def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]: @@ -22,7 +22,7 @@ class SchemaManager: """ return self.registry.get_all_schemas_for_version(version) - def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]: + def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Mapping[str, Any] | None: """ Get a specific schema by name diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py index ef3022352a..4abc9c068d 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc. from collections.abc import Mapping from datetime import datetime -from typing import Any, Optional +from typing import Any from pydantic import BaseModel, Field, PrivateAttr @@ -53,9 +53,9 @@ class WorkflowNodeExecution(BaseModel): # Execution data # The `inputs` and `outputs` fields hold the full content - inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node - process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data - outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + inputs: Mapping[str, Any] | None = None # Input variables used by this node + process_data: Mapping[str, Any] | None = None # Intermediate processing data + outputs: Mapping[str, Any] | None = None # Output variables produced by this node # Execution state status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index eb58ba14c1..6cf0c91c30 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,5 +1,5 @@ from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -50,7 +50,7 @@ class DatasourceNode(Node): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = DatasourceNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -59,7 +59,7 @@ class DatasourceNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: @@ -179,7 +179,7 @@ class DatasourceNode(Node): related_id = datasource_info.get("related_id") if not related_id: raise DatasourceNodeError("File is not exist") - upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first() + upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first() if not upload_file: raise ValueError("Invalid upload file Info") diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index b182928baa..4802d3ed98 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,4 +1,4 @@ -from typing import Any, Literal, Optional, Union +from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo @@ -10,7 +10,7 @@ class DatasourceEntity(BaseModel): plugin_id: str provider_name: str # redundancy provider_type: str - datasource_name: Optional[str] = "local_file" + datasource_name: str | None = "local_file" datasource_configurations: dict[str, Any] | None = None plugin_unique_identifier: str | None = None # redundancy @@ -19,7 +19,7 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity): class DatasourceInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] - type: Optional[Literal["mixed", "variable", "constant"]] = None + type: Literal["mixed", "variable", "constant"] | None = None @field_validator("type", mode="before") @classmethod diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 85c0f695c6..2a2e983a0c 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,4 +1,4 @@ -from typing import Literal, Optional, Union +from typing import Literal, Union from pydantic import BaseModel @@ -65,12 +65,12 @@ class RetrievalSetting(BaseModel): search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"] top_k: int - score_threshold: Optional[float] = 0.5 + score_threshold: float | None = 0.5 score_threshold_enabled: bool = False reranking_mode: str = "reranking_model" reranking_enable: bool = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class IndexMethod(BaseModel): @@ -107,10 +107,10 @@ class OnlineDocumentInfo(BaseModel): """ provider: str - workspace_id: Optional[str] = None + workspace_id: str | None = None page_id: str page_type: str - icon: Optional[OnlineDocumentIcon] = None + icon: OnlineDocumentIcon | None = None class WebsiteInfo(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index d7641bc123..d5ced1a246 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,7 +2,7 @@ import datetime import logging import time from collections.abc import Mapping -from typing import Any, Optional, cast +from typing import Any, cast from sqlalchemy import func, select @@ -43,7 +43,7 @@ class KnowledgeIndexNode(Node): def init_node_data(self, data: Mapping[str, Any]) -> None: self._node_data = KnowledgeIndexNodeData.model_validate(data) - def _get_error_strategy(self) -> Optional[ErrorStrategy]: + def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy def _get_retry_config(self) -> RetryConfig: @@ -52,7 +52,7 @@ class KnowledgeIndexNode(Node): def _get_title(self) -> str: return self._node_data.title - def _get_description(self) -> Optional[str]: + def _get_description(self) -> str | None: return self._node_data.desc def _get_default_value_dict(self) -> dict[str, Any]: diff --git a/api/models/dataset.py b/api/models/dataset.py index 2c03a0c30c..d620d56006 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -10,7 +10,7 @@ import re import time from datetime import datetime from json import JSONDecodeError -from typing import Any, Optional, cast +from typing import Any, cast import sqlalchemy as sa from sqlalchemy import DateTime, String, func, select @@ -76,7 +76,7 @@ class Dataset(Base): @property def total_documents(self): - return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar() + return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() @property def total_available_documents(self): @@ -173,10 +173,10 @@ class Dataset(Base): ) @property - def doc_form(self) -> Optional[str]: + def doc_form(self) -> str | None: if self.chunk_structure: return self.chunk_structure - document = db.session.query(Document).filter(Document.dataset_id == self.id).first() + document = db.session.query(Document).where(Document.dataset_id == self.id).first() if document: return document.doc_form return None @@ -234,7 +234,7 @@ class Dataset(Base): @property def is_published(self): if self.pipeline_id: - pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first() if pipeline: return pipeline.is_published return False @@ -1244,7 +1244,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined] @property def created_user_name(self): - account = db.session.query(Account).filter(Account.id == self.created_by).first() + account = db.session.query(Account).where(Account.id == self.created_by).first() if account: return account.name return "" @@ -1274,7 +1274,7 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined] @property def created_user_name(self): - account = db.session.query(Account).filter(Account.id == self.created_by).first() + account = db.session.query(Account).where(Account.id == self.created_by).first() if account: return account.name return "" @@ -1297,7 +1297,7 @@ class Pipeline(Base): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def retrieve_dataset(self, session: Session): - return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first() + return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() class DocumentPipelineExecutionLog(Base): diff --git a/api/models/workflow.py b/api/models/workflow.py index bb7ea2c074..5f604a51a8 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1546,7 +1546,7 @@ class WorkflowDraftVariableFile(Base): comment="Size of the original variable content in bytes", ) - length: Mapped[Optional[int]] = mapped_column( + length: Mapped[int | None] = mapped_column( sa.Integer, nullable=True, comment=( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 84e36cb80b..798233fd95 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -7,7 +7,7 @@ import time import uuid from collections import Counter from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Any, Literal import sqlalchemy as sa from sqlalchemy import exists, func, select @@ -315,8 +315,8 @@ class DatasetService: return dataset @staticmethod - def get_dataset(dataset_id) -> Optional[Dataset]: - dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() + def get_dataset(dataset_id) -> Dataset | None: + dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset @staticmethod diff --git a/api/services/entities/knowledge_entities/rag_pipeline_entities.py b/api/services/entities/knowledge_entities/rag_pipeline_entities.py index e215a89c15..ac96b5c8ad 100644 --- a/api/services/entities/knowledge_entities/rag_pipeline_entities.py +++ b/api/services/entities/knowledge_entities/rag_pipeline_entities.py @@ -1,13 +1,13 @@ -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, field_validator class IconInfo(BaseModel): icon: str - icon_background: Optional[str] = None - icon_type: Optional[str] = None - icon_url: Optional[str] = None + icon_background: str | None = None + icon_type: str | None = None + icon_url: str | None = None class PipelineTemplateInfoEntity(BaseModel): @@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel): description: str icon_info: IconInfo permission: str - partial_member_list: Optional[list[str]] = None - yaml_content: Optional[str] = None + partial_member_list: list[str] | None = None + yaml_content: str | None = None class RerankingModelConfig(BaseModel): @@ -30,8 +30,8 @@ class RerankingModelConfig(BaseModel): Reranking Model Config. """ - reranking_provider_name: Optional[str] = "" - reranking_model_name: Optional[str] = "" + reranking_provider_name: str | None = "" + reranking_model_name: str | None = "" class VectorSetting(BaseModel): @@ -57,8 +57,8 @@ class WeightedScoreConfig(BaseModel): Weighted score Config. """ - vector_setting: Optional[VectorSetting] - keyword_setting: Optional[KeywordSetting] + vector_setting: VectorSetting | None + keyword_setting: KeywordSetting | None class EmbeddingSetting(BaseModel): @@ -85,12 +85,12 @@ class RetrievalSetting(BaseModel): search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"] top_k: int - score_threshold: Optional[float] = 0.5 + score_threshold: float | None = 0.5 score_threshold_enabled: bool = False - reranking_mode: Optional[str] = "reranking_model" - reranking_enable: Optional[bool] = True - reranking_model: Optional[RerankingModelConfig] = None - weights: Optional[WeightedScoreConfig] = None + reranking_mode: str | None = "reranking_model" + reranking_enable: bool | None = True + reranking_model: RerankingModelConfig | None = None + weights: WeightedScoreConfig | None = None class IndexMethod(BaseModel): @@ -112,7 +112,7 @@ class KnowledgeConfiguration(BaseModel): indexing_technique: Literal["high_quality", "economy"] embedding_model_provider: str = "" embedding_model: str = "" - keyword_number: Optional[int] = 10 + keyword_number: int | None = 10 retrieval_model: RetrievalSetting @field_validator("embedding_model_provider", mode="before") diff --git a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py index 35005fad71..41f46a55a7 100644 --- a/api/services/rag_pipeline/entity/pipeline_service_api_entities.py +++ b/api/services/rag_pipeline/entity/pipeline_service_api_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import Any, Optional +from typing import Any from pydantic import BaseModel @@ -9,7 +9,7 @@ class DatasourceNodeRunApiEntity(BaseModel): node_id: str inputs: Mapping[str, Any] datasource_type: str - credential_id: Optional[str] = None + credential_id: str | None = None is_published: bool diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index 563174c528..e6cee64df6 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -108,7 +108,7 @@ class PipelineGenerateService: Update document status to waiting :param document_id: document id """ - document = db.session.query(Document).filter(Document.id == document_id).first() + document = db.session.query(Document).where(Document.id == document_id).first() if document: document.indexing_status = "waiting" db.session.add(document) diff --git a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py index b0fa54115c..24baeb73b5 100644 --- a/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/built_in/built_in_retrieval.py @@ -1,7 +1,6 @@ import json from os import path from pathlib import Path -from typing import Optional from flask import current_app @@ -14,7 +13,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json """ - builtin_data: Optional[dict] = None + builtin_data: dict | None = None def get_type(self) -> str: return PipelineTemplateType.BUILTIN @@ -54,7 +53,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return builtin_data.get("pipeline_templates", {}).get(language, {}) @classmethod - def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from builtin. :param template_id: Template ID diff --git a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py index 3380d23ec4..82a0a08ec6 100644 --- a/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/customized/customized_retrieval.py @@ -1,4 +1,3 @@ -from typing import Optional import yaml from flask_login import current_user @@ -56,14 +55,14 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from db. :param template_id: Template ID :return: """ pipeline_template = ( - db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first() + db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first() ) if not pipeline_template: return None diff --git a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py index 709702fe11..a544767465 100644 --- a/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/database/database_retrieval.py @@ -1,4 +1,3 @@ -from typing import Optional import yaml @@ -33,7 +32,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ pipeline_built_in_templates: list[PipelineBuiltInTemplate] = ( - db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all() + db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all() ) recommended_pipelines_results = [] @@ -53,7 +52,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return {"pipeline_templates": recommended_pipelines_results} @classmethod - def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from db. :param pipeline_id: Pipeline ID @@ -61,7 +60,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): """ # is in public recommended list pipeline_template = ( - db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == template_id).first() + db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first() ) if not pipeline_template: diff --git a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py index fa6a38a357..21c30a4986 100644 --- a/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py +++ b/api/services/rag_pipeline/pipeline_template/pipeline_template_base.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from typing import Optional class PipelineTemplateRetrievalBase(ABC): @@ -10,7 +9,7 @@ class PipelineTemplateRetrievalBase(ABC): raise NotImplementedError @abstractmethod - def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]: + def get_pipeline_template_detail(self, template_id: str) -> dict | None: raise NotImplementedError @abstractmethod diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index e541a7bc0b..8f96842337 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -1,5 +1,4 @@ import logging -from typing import Optional import requests @@ -36,7 +35,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> Optional[dict]: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None: """ Fetch pipeline template detail from dify official. :param template_id: Pipeline ID diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index e27d78b980..88e1dab23e 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -5,7 +5,7 @@ import threading import time from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional, Union, cast +from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user @@ -112,7 +112,7 @@ class RagPipelineService: return result @classmethod - def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]: + def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: """ Get pipeline template detail. :param template_id: template id @@ -121,12 +121,12 @@ class RagPipelineService: if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) return built_in_result else: mode = "customized" retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() - customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id) + customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) return customized_result @classmethod @@ -185,7 +185,7 @@ class RagPipelineService: db.session.delete(customized_template) db.session.commit() - def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None: """ Get draft workflow """ @@ -203,7 +203,7 @@ class RagPipelineService: # return draft workflow return workflow - def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]: + def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None: """ Get published workflow """ @@ -267,7 +267,7 @@ class RagPipelineService: *, pipeline: Pipeline, graph: dict, - unique_hash: Optional[str], + unique_hash: str | None, account: Account, environment_variables: Sequence[Variable], conversation_variables: Sequence[Variable], @@ -387,9 +387,7 @@ class RagPipelineService: return default_block_configs - def get_default_block_config( - self, node_type: str, filters: Optional[dict] = None - ) -> Optional[Mapping[str, object]]: + def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None: """ Get default config of node. :param node_type: node type @@ -495,7 +493,7 @@ class RagPipelineService: account: Account, datasource_type: str, is_published: bool, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Generator[Mapping[str, Any], None, None]: """ Run published workflow datasource @@ -661,7 +659,7 @@ class RagPipelineService: account: Account, datasource_type: str, is_published: bool, - credential_id: Optional[str] = None, + credential_id: str | None = None, ) -> Mapping[str, Any]: """ Run published workflow datasource @@ -876,7 +874,7 @@ class RagPipelineService: if invoke_from.value == InvokeFrom.PUBLISHED.value: document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) if document_id: - document = db.session.query(Document).filter(Document.id == document_id.value).first() + document = db.session.query(Document).where(Document.id == document_id.value).first() if document: document.indexing_status = "error" document.error = error @@ -887,7 +885,7 @@ class RagPipelineService: def update_workflow( self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict - ) -> Optional[Workflow]: + ) -> Workflow | None: """ Update workflow attributes @@ -1057,7 +1055,7 @@ class RagPipelineService: return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]: + def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None: """ Get workflow run detail @@ -1113,12 +1111,12 @@ class RagPipelineService: """ Publish customized pipeline template """ - pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found") if not pipeline.workflow_id: raise ValueError("Pipeline workflow not found") - workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError("Workflow not found") with Session(db.engine) as session: @@ -1142,7 +1140,7 @@ class RagPipelineService: max_position = ( db.session.query(func.max(PipelineCustomizedTemplate.position)) - .filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) + .where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id) .scalar() ) @@ -1278,7 +1276,7 @@ class RagPipelineService: # Query active recommended plugins pipeline_recommended_plugins = ( db.session.query(PipelineRecommendedPlugin) - .filter(PipelineRecommendedPlugin.active == True) + .where(PipelineRecommendedPlugin.active == True) .order_by(PipelineRecommendedPlugin.position.asc()) .all() ) @@ -1329,12 +1327,12 @@ class RagPipelineService: """ document_pipeline_excution_log = ( db.session.query(DocumentPipelineExecutionLog) - .filter(DocumentPipelineExecutionLog.document_id == document.id) + .where(DocumentPipelineExecutionLog.document_id == document.id) .first() ) if not document_pipeline_excution_log: raise ValueError("Document pipeline execution log not found") - pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first() + pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first() if not pipeline: raise ValueError("Pipeline not found") # convert to app config diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index fe92f6b084..e21d2d56bc 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -6,7 +6,7 @@ import uuid from collections.abc import Mapping from datetime import UTC, datetime from enum import StrEnum -from typing import Optional, cast +from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -66,11 +66,11 @@ class ImportStatus(StrEnum): class RagPipelineImportInfo(BaseModel): id: str status: ImportStatus - pipeline_id: Optional[str] = None + pipeline_id: str | None = None current_dsl_version: str = CURRENT_DSL_VERSION imported_dsl_version: str = "" error: str = "" - dataset_id: Optional[str] = None + dataset_id: str | None = None class CheckDependenciesResult(BaseModel): @@ -121,12 +121,12 @@ class RagPipelineDslService: *, account: Account, import_mode: str, - yaml_content: Optional[str] = None, - yaml_url: Optional[str] = None, - pipeline_id: Optional[str] = None, - dataset: Optional[Dataset] = None, - dataset_name: Optional[str] = None, - icon_info: Optional[IconInfo] = None, + yaml_content: str | None = None, + yaml_url: str | None = None, + pipeline_id: str | None = None, + dataset: Dataset | None = None, + dataset_name: str | None = None, + icon_info: IconInfo | None = None, ) -> RagPipelineImportInfo: """Import an app from YAML content or URL.""" import_id = str(uuid.uuid4()) @@ -530,10 +530,10 @@ class RagPipelineDslService: def _create_or_update_pipeline( self, *, - pipeline: Optional[Pipeline], + pipeline: Pipeline | None, data: dict, account: Account, - dependencies: Optional[list[PluginDependency]] = None, + dependencies: list[PluginDependency] | None = None, ) -> Pipeline: """Create a new app or update an existing one.""" if not account.current_tenant_id: diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index 78440b4889..c2dbb484cf 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -1,7 +1,6 @@ import json from datetime import UTC, datetime from pathlib import Path -from typing import Optional from uuid import uuid4 import yaml @@ -21,7 +20,7 @@ from services.plugin.plugin_service import PluginService class RagPipelineTransformService: def transform_dataset(self, dataset_id: str): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline": @@ -90,7 +89,7 @@ class RagPipelineTransformService: "status": "success", } - def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]): + def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None): pipeline_yaml = {} if doc_form == "text_model": match datasource_type: @@ -152,7 +151,7 @@ class RagPipelineTransformService: return node def _deal_knowledge_index( - self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict + self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict ): knowledge_configuration_dict = node.get("data", {}) knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict) @@ -289,7 +288,7 @@ class RagPipelineTransformService: jina_node_id = "1752491761974" firecrawl_node_id = "1752565402678" - documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all() + documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all() for document in documents: data_source_info_dict = document.data_source_info_dict @@ -299,7 +298,7 @@ class RagPipelineTransformService: document.data_source_type = "local_file" file_id = data_source_info_dict.get("upload_file_id") if file_id: - file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() + file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() if file: data_source_info = json.dumps( { diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index 9b3857b4a5..447443703a 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -1,6 +1,5 @@ import logging import time -from typing import Optional import click from celery import shared_task @@ -17,7 +16,7 @@ logger = logging.getLogger(__name__) @shared_task(queue="dataset") -def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]): +def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]): """ Clean document when document deleted. :param document_ids: document ids diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index dc266aef65..df4a76d94f 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -44,7 +44,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -76,12 +76,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() @@ -100,7 +100,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update( + db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( {"indexing_status": "indexing"}, synchronize_session=False ) db.session.commit() @@ -148,12 +148,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): documents.append(document) # save vector index index_processor.load(dataset, documents, with_keywords=False) - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "completed"}, synchronize_session=False ) db.session.commit() except Exception as e: - db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update( + db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( {"indexing_status": "error", "error": str(e)}, synchronize_session=False ) db.session.commit() diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py index 4780e48558..028f635188 100644 --- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py @@ -104,20 +104,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], with Session(db.engine, expire_on_commit=False) as session: # Load required entities - account = session.query(Account).filter(Account.id == user_id).first() + account = session.query(Account).where(Account.id == user_id).first() if not account: raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + tenant = session.query(Tenant).where(Tenant.id == tenant_id).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found") account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first() if not pipeline: raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError(f"Workflow {pipeline.workflow_id} not found") diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 72916972df..ee904c4649 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -125,20 +125,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], with Session(db.engine) as session: # Load required entities - account = session.query(Account).filter(Account.id == user_id).first() + account = session.query(Account).where(Account.id == user_id).first() if not account: raise ValueError(f"Account {user_id} not found") - tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first() + tenant = session.query(Tenant).where(Tenant.id == tenant_id).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found") account.current_tenant = tenant - pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first() + pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first() if not pipeline: raise ValueError(f"Pipeline {pipeline_id} not found") - workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first() + workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first() if not workflow: raise ValueError(f"Workflow {pipeline.workflow_id} not found") diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index f4e9b52778..9c12696824 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -38,7 +38,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ if not user: logger.info(click.style(f"User not found: {user_id}", fg="red")) return - tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first() + tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first() if not tenant: raise ValueError("Tenant not found") user.current_tenant = tenant