chore(api): apply autofix manully

This commit is contained in:
QuantumGhost 2025-09-17 22:31:19 +08:00
parent 73d4bb596a
commit eefcd3ecc4
47 changed files with 241 additions and 268 deletions

View File

@ -104,7 +104,7 @@ class CustomizedPipelineTemplateApi(Resource):
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")

View File

@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
@ -10,7 +9,7 @@ from models.dataset import Pipeline
def get_rag_pipeline(
view: Optional[Callable] = None,
view: Callable | None = None,
):
def decorator(view_func):
@wraps(view_func)

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -134,8 +134,8 @@ class RagPipelineVariableEntity(VariableEntity):
Rag Pipeline Variable Entity.
"""
tooltips: Optional[str] = None
placeholder: Optional[str] = None
tooltips: str | None = None
placeholder: str | None = None
belong_to_node_id: str
@ -298,7 +298,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: Optional[AppAdditionalFeatures] = None
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None

View File

@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, cast, overload
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -69,7 +69,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Generator[Mapping | str, None, None]: ...
@ -84,7 +84,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@ -99,7 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
@ -113,7 +113,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
@ -314,7 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -331,7 +331,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
@ -568,7 +568,7 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
Generate worker in a new thread.
@ -801,11 +801,11 @@ class PipelineGenerator(BaseAppGenerator):
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: Optional[str],
bucket: str | None,
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: Optional[dict] = None,
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.

View File

@ -1,6 +1,6 @@
import logging
import time
from typing import Optional, cast
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@ -40,7 +40,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
:param application_generate_entity: application generate entity
@ -69,13 +69,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
@ -188,7 +188,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
@ -205,7 +205,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph

View File

@ -242,7 +242,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
@ -256,9 +256,9 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: Optional[str] = None
original_document_id: Optional[str] = None
start_node_id: Optional[str] = None
document_id: str | None = None
original_document_id: str | None = None
start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@ -252,8 +252,8 @@ class NodeStartStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
created_at: int
extras: dict[str, object] = Field(default_factory=dict)
@ -310,12 +310,12 @@ class NodeFinishStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Optional[Mapping[str, Any]] = None
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Optional[Mapping[str, Any]] = None
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: str
error: str | None = None
@ -382,12 +382,12 @@ class NodeRetryStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Optional[Mapping[str, Any]] = None
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Optional[Mapping[str, Any]] = None
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: str
error: str | None = None
@ -503,11 +503,11 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: str | None = None
@ -541,8 +541,8 @@ class LoopNodeStartStreamResponse(StreamResponse):
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@ -590,11 +590,11 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: str | None = None

View File

@ -17,9 +17,9 @@ class DatasourceRuntime(BaseModel):
"""
tenant_id: str
datasource_id: Optional[str] = None
datasource_id: str | None = None
invoke_from: Optional["InvokeFrom"] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
datasource_invoke_from: DatasourceInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@ -6,7 +6,7 @@ import os
import time
from datetime import datetime
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from typing import Union
from uuid import uuid4
import httpx
@ -62,10 +62,10 @@ class DatasourceFileManager:
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
filename: str | None = None,
) -> UploadFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
@ -106,7 +106,7 @@ class DatasourceFileManager:
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
conversation_id: str | None = None,
) -> ToolFile:
# try to download image
try:
@ -153,10 +153,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == id,
)
db.session.query(UploadFile).where(UploadFile.id == id)
.first()
)
@ -177,10 +174,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
db.session.query(MessageFile).where(MessageFile.id == id)
.first()
)
@ -197,10 +191,7 @@ class DatasourceFileManager:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
db.session.query(ToolFile).where(ToolFile.id == tool_file_id)
.first()
)
@ -221,10 +212,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == upload_file_id,
)
db.session.query(UploadFile).where(UploadFile.id == upload_file_id)
.first()
)

View File

@ -12,9 +12,9 @@ class DatasourceApiEntity(BaseModel):
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[DatasourceParameter]] = None
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
@ -28,12 +28,12 @@ class DatasourceProviderApiEntity(BaseModel):
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
masked_credentials: dict | None = None
original_credentials: dict | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)

View File

@ -1,4 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
@ -9,9 +8,9 @@ class I18nObject(BaseModel):
"""
en_US: str
zh_Hans: Optional[str] = Field(default=None)
pt_BR: Optional[str] = Field(default=None)
ja_JP: Optional[str] = Field(default=None)
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)

View File

@ -1,6 +1,6 @@
import enum
from enum import Enum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from yarl import URL
@ -80,7 +80,7 @@ class DatasourceParameter(PluginParameter):
name: str,
typ: DatasourceParameterType,
required: bool,
options: Optional[list[str]] = None,
options: list[str] | None = None,
) -> "DatasourceParameter":
"""
get a simple datasource parameter
@ -120,14 +120,14 @@ class DatasourceIdentity(BaseModel):
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: Optional[str] = None
icon: str | None = None
class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: Optional[dict] = None
output_schema: dict | None = None
@field_validator("parameters", mode="before")
@classmethod
@ -141,7 +141,7 @@ class DatasourceProviderIdentity(BaseModel):
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
tags: list[ToolLabelEnum] | None = Field(
default=[],
description="The tags of the tool",
)
@ -169,7 +169,7 @@ class DatasourceProviderEntity(BaseModel):
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
oauth_schema: OAuthSchema | None = None
provider_type: DatasourceProviderType
@ -183,8 +183,8 @@ class DatasourceInvokeMeta(BaseModel):
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
error: str | None = None
tool_config: dict | None = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
@ -233,10 +233,10 @@ class OnlineDocumentPage(BaseModel):
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: Optional[dict] = Field(None, description="The page icon")
page_icon: dict | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: Optional[str] = Field(None, description="The parent page id")
parent_id: str | None = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
@ -244,9 +244,9 @@ class OnlineDocumentInfo(BaseModel):
Online document info
"""
workspace_id: Optional[str] = Field(None, description="The workspace id")
workspace_name: Optional[str] = Field(None, description="The workspace name")
workspace_icon: Optional[str] = Field(None, description="The workspace icon")
workspace_id: str | None = Field(None, description="The workspace id")
workspace_name: str | None = Field(None, description="The workspace name")
workspace_icon: str | None = Field(None, description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
@ -307,10 +307,10 @@ class WebSiteInfo(BaseModel):
Website info
"""
status: Optional[str] = Field(..., description="crawl job status")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
status: str | None = Field(..., description="crawl job status")
web_info_list: list[WebSiteInfoDetail] | None = []
total: int | None = Field(default=0, description="The total number of websites")
completed: int | None = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
@ -346,10 +346,10 @@ class OnlineDriveFileBucket(BaseModel):
Online drive file bucket
"""
bucket: Optional[str] = Field(None, description="The file bucket")
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
@ -357,10 +357,10 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
Get online drive file list request
"""
bucket: Optional[str] = Field(None, description="The file bucket")
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):
@ -377,4 +377,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
"""
id: str = Field(..., description="The id of the file")
bucket: Optional[str] = Field(None, description="The name of the bucket")
bucket: str | None = Field(None, description="The name of the bucket")

View File

@ -1,7 +1,6 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
@ -17,7 +16,7 @@ class DatasourceFileMessageTransformer:
messages: Generator[DatasourceMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
conversation_id: str | None = None,
) -> Generator[DatasourceMessage, None, None]:
"""
Transform datasource message and handle file download
@ -121,5 +120,5 @@ class DatasourceFileMessageTransformer:
yield message
@classmethod
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str:
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@ -3,7 +3,6 @@ import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
@ -169,9 +168,9 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: Optional[str] = None
typ: str | None = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE

View File

@ -1,4 +1,3 @@
from typing import Optional
from pydantic import BaseModel
@ -16,7 +15,7 @@ class QAPreviewDetail(BaseModel):
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
qa_preview: list[QAPreviewDetail] | None = None
class PipelineDataset(BaseModel):
@ -30,10 +29,10 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: Optional[dict] = None
data_source_info: dict | None = None
name: str
indexing_status: str
error: Optional[str] = None
error: str | None = None
enabled: bool

View File

@ -1,7 +1,7 @@
import datetime
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Optional
from typing import Any
from packaging.version import InvalidVersion, Version
from pydantic import BaseModel, Field, field_validator, model_validator
@ -67,10 +67,10 @@ class PluginCategory(StrEnum):
class PluginDeclaration(BaseModel):
class Plugins(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
datasources: Optional[list[str]] = Field(default_factory=list[str])
tools: list[str] | None = Field(default_factory=list[str])
models: list[str] | None = Field(default_factory=list[str])
endpoints: list[str] | None = Field(default_factory=list[str])
datasources: list[str] | None = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: str | None = Field(default=None)
@ -101,11 +101,11 @@ class PluginDeclaration(BaseModel):
tags: list[str] = Field(default_factory=list)
repo: str | None = Field(default=None)
verified: bool = Field(default=False)
tool: Optional[ToolProviderEntity] = None
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
datasource: Optional[DatasourceProviderEntity] = None
tool: ToolProviderEntity | None = None
model: ProviderEntity | None = None
endpoint: EndpointProviderDeclaration | None = None
agent_strategy: AgentStrategyProviderEntity | None = None
datasource: DatasourceProviderEntity | None = None
meta: Meta
@field_validator("version")

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -27,12 +27,12 @@ class DatasourceErrorEvent(BaseDatasourceEvent):
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
data: Mapping[str, Any] | list = Field(..., description="result")
total: Optional[int] = Field(default=0, description="total")
completed: Optional[int] = Field(default=0, description="completed")
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
total: int | None = Field(default=0, description="total")
completed: int | None = Field(default=0, description="completed")
time_consuming: float | None = Field(default=0.0, description="time consuming")
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
total: Optional[int] = Field(..., description="total")
completed: Optional[int] = Field(..., description="completed")
total: int | None = Field(..., description="total")
completed: int | None = Field(..., description="completed")

View File

@ -1,4 +1,3 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict
@ -11,7 +10,7 @@ class NotionInfo(BaseModel):
Notion import info.
"""
credential_id: Optional[str] = None
credential_id: str | None = None
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str

View File

@ -1,7 +1,7 @@
import json
import logging
import operator
from typing import Any, Optional, cast
from typing import Any, cast
import requests
@ -35,9 +35,9 @@ class NotionExtractor(BaseExtractor):
notion_obj_id: str,
notion_page_type: str,
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None,
credential_id: Optional[str] = None,
document_model: DocumentModel | None = None,
notion_access_token: str | None = None,
credential_id: str | None = None,
):
self._notion_access_token = None
self._document_model = document_model
@ -369,7 +369,7 @@ class NotionExtractor(BaseExtractor):
return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
def _get_access_token(cls, tenant_id: str, credential_id: str | None) -> str:
# get credential from tenant_id and credential_id
if not credential_id:
raise Exception(f"No credential id found for tenant {tenant_id}")

View File

@ -7,7 +7,7 @@ import json
import logging
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional, TypeVar, Union
from typing import Any, TypeVar, Union
import psycopg2.errors
from sqlalchemy import UnaryExpression, asc, desc, select
@ -530,7 +530,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
order_config: OrderConfig | None = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecution]:
"""

View File

@ -85,7 +85,7 @@ class SchemaRegistry:
except (OSError, json.JSONDecodeError) as e:
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
def get_schema(self, uri: str) -> Optional[Any]:
def get_schema(self, uri: str) -> Any | None:
"""Retrieves a schema by URI with version support"""
version, schema_name = self._parse_uri(uri)
if not version or not schema_name:

View File

@ -3,7 +3,7 @@ import re
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Union
from core.schemas.registry import SchemaRegistry
@ -53,8 +53,8 @@ class QueueItem:
"""Represents an item in the BFS queue"""
current: Any
parent: Optional[Any]
key: Optional[Union[str, int]]
parent: Any | None
key: Union[str, int] | None
depth: int
ref_path: set[str]
@ -65,7 +65,7 @@ class SchemaResolver:
_cache: dict[str, SchemaDict] = {}
_cache_lock = threading.Lock()
def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10):
"""
Initialize the schema resolver
@ -202,7 +202,7 @@ class SchemaResolver:
)
)
def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None:
"""Get resolved schema from cache or registry"""
# Check cache first
with self._cache_lock:
@ -223,7 +223,7 @@ class SchemaResolver:
def resolve_dify_schema_refs(
schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30
) -> SchemaType:
"""
Resolve $ref references in Dify schema to actual schema content

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.schemas.registry import SchemaRegistry
@ -7,7 +7,7 @@ from core.schemas.registry import SchemaRegistry
class SchemaManager:
"""Schema manager provides high-level schema operations"""
def __init__(self, registry: Optional[SchemaRegistry] = None):
def __init__(self, registry: SchemaRegistry | None = None):
self.registry = registry or SchemaRegistry.default_registry()
def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:
@ -22,7 +22,7 @@ class SchemaManager:
"""
return self.registry.get_all_schemas_for_version(version)
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Mapping[str, Any] | None:
"""
Get a specific schema by name

View File

@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field, PrivateAttr
@ -53,9 +53,9 @@ class WorkflowNodeExecution(BaseModel):
# Execution data
# The `inputs` and `outputs` fields hold the full content
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -50,7 +50,7 @@ class DatasourceNode(Node):
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DatasourceNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -59,7 +59,7 @@ class DatasourceNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -179,7 +179,7 @@ class DatasourceNode(Node):
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
@ -10,7 +10,7 @@ class DatasourceEntity(BaseModel):
plugin_id: str
provider_name: str # redundancy
provider_type: str
datasource_name: Optional[str] = "local_file"
datasource_name: str | None = "local_file"
datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy
@ -19,7 +19,7 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Optional[Literal["mixed", "variable", "constant"]] = None
type: Literal["mixed", "variable", "constant"] | None = None
@field_validator("type", mode="before")
@classmethod

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel
@ -65,12 +65,12 @@ class RetrievalSetting(BaseModel):
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
@ -107,10 +107,10 @@ class OnlineDocumentInfo(BaseModel):
"""
provider: str
workspace_id: Optional[str] = None
workspace_id: str | None = None
page_id: str
page_type: str
icon: Optional[OnlineDocumentIcon] = None
icon: OnlineDocumentIcon | None = None
class WebsiteInfo(BaseModel):

View File

@ -2,7 +2,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import func, select
@ -43,7 +43,7 @@ class KnowledgeIndexNode(Node):
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeIndexNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -52,7 +52,7 @@ class KnowledgeIndexNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:

View File

@ -10,7 +10,7 @@ import re
import time
from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
@ -76,7 +76,7 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property
def total_available_documents(self):
@ -173,10 +173,10 @@ class Dataset(Base):
)
@property
def doc_form(self) -> Optional[str]:
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document:
return document.doc_form
return None
@ -234,7 +234,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
if pipeline:
return pipeline.is_published
return False
@ -1244,7 +1244,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
@property
def created_user_name(self):
account = db.session.query(Account).filter(Account.id == self.created_by).first()
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
@ -1274,7 +1274,7 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
@property
def created_user_name(self):
account = db.session.query(Account).filter(Account.id == self.created_by).first()
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
@ -1297,7 +1297,7 @@ class Pipeline(Base): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def retrieve_dataset(self, session: Session):
return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
class DocumentPipelineExecutionLog(Base):

View File

@ -1546,7 +1546,7 @@ class WorkflowDraftVariableFile(Base):
comment="Size of the original variable content in bytes",
)
length: Mapped[Optional[int]] = mapped_column(
length: Mapped[int | None] = mapped_column(
sa.Integer,
nullable=True,
comment=(

View File

@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
import sqlalchemy as sa
from sqlalchemy import exists, func, select
@ -315,8 +315,8 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod

View File

@ -1,13 +1,13 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, field_validator
class IconInfo(BaseModel):
icon: str
icon_background: Optional[str] = None
icon_type: Optional[str] = None
icon_url: Optional[str] = None
icon_background: str | None = None
icon_type: str | None = None
icon_url: str | None = None
class PipelineTemplateInfoEntity(BaseModel):
@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: Optional[list[str]] = None
yaml_content: Optional[str] = None
partial_member_list: list[str] | None = None
yaml_content: str | None = None
class RerankingModelConfig(BaseModel):
@ -30,8 +30,8 @@ class RerankingModelConfig(BaseModel):
Reranking Model Config.
"""
reranking_provider_name: Optional[str] = ""
reranking_model_name: Optional[str] = ""
reranking_provider_name: str | None = ""
reranking_model_name: str | None = ""
class VectorSetting(BaseModel):
@ -57,8 +57,8 @@ class WeightedScoreConfig(BaseModel):
Weighted score Config.
"""
vector_setting: Optional[VectorSetting]
keyword_setting: Optional[KeywordSetting]
vector_setting: VectorSetting | None
keyword_setting: KeywordSetting | None
class EmbeddingSetting(BaseModel):
@ -85,12 +85,12 @@ class RetrievalSetting(BaseModel):
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: Optional[str] = "reranking_model"
reranking_enable: Optional[bool] = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_mode: str | None = "reranking_model"
reranking_enable: bool | None = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
@ -112,7 +112,7 @@ class KnowledgeConfiguration(BaseModel):
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: str = ""
embedding_model: str = ""
keyword_number: Optional[int] = 10
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
@field_validator("embedding_model_provider", mode="before")

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel
@ -9,7 +9,7 @@ class DatasourceNodeRunApiEntity(BaseModel):
node_id: str
inputs: Mapping[str, Any]
datasource_type: str
credential_id: Optional[str] = None
credential_id: str | None = None
is_published: bool

View File

@ -108,7 +108,7 @@ class PipelineGenerateService:
Update document status to waiting
:param document_id: document id
"""
document = db.session.query(Document).filter(Document.id == document_id).first()
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
db.session.add(document)

View File

@ -1,7 +1,6 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
@ -14,7 +13,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
"""
builtin_data: Optional[dict] = None
builtin_data: dict | None = None
def get_type(self) -> str:
return PipelineTemplateType.BUILTIN
@ -54,7 +53,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from builtin.
:param template_id: Template ID

View File

@ -1,4 +1,3 @@
from typing import Optional
import yaml
from flask_login import current_user
@ -56,14 +55,14 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID
:return:
"""
pipeline_template = (
db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not pipeline_template:
return None

View File

@ -1,4 +1,3 @@
from typing import Optional
import yaml
@ -33,7 +32,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
)
recommended_pipelines_results = []
@ -53,7 +52,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param pipeline_id: Pipeline ID
@ -61,7 +60,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == template_id).first()
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
)
if not pipeline_template:

View File

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional
class PipelineTemplateRetrievalBase(ABC):
@ -10,7 +9,7 @@ class PipelineTemplateRetrievalBase(ABC):
raise NotImplementedError
@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]:
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
raise NotImplementedError
@abstractmethod

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
import requests
@ -36,7 +35,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from dify official.
:param template_id: Pipeline ID

View File

@ -5,7 +5,7 @@ import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from uuid import uuid4
from flask_login import current_user
@ -112,7 +112,7 @@ class RagPipelineService:
return result
@classmethod
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
"""
Get pipeline template detail.
:param template_id: template id
@ -121,12 +121,12 @@ class RagPipelineService:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
return built_in_result
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
return customized_result
@classmethod
@ -185,7 +185,7 @@ class RagPipelineService:
db.session.delete(customized_template)
db.session.commit()
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None:
"""
Get draft workflow
"""
@ -203,7 +203,7 @@ class RagPipelineService:
# return draft workflow
return workflow
def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None:
"""
Get published workflow
"""
@ -267,7 +267,7 @@ class RagPipelineService:
*,
pipeline: Pipeline,
graph: dict,
unique_hash: Optional[str],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
@ -387,9 +387,7 @@ class RagPipelineService:
return default_block_configs
def get_default_block_config(
self, node_type: str, filters: Optional[dict] = None
) -> Optional[Mapping[str, object]]:
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
"""
Get default config of node.
:param node_type: node type
@ -495,7 +493,7 @@ class RagPipelineService:
account: Account,
datasource_type: str,
is_published: bool,
credential_id: Optional[str] = None,
credential_id: str | None = None,
) -> Generator[Mapping[str, Any], None, None]:
"""
Run published workflow datasource
@ -661,7 +659,7 @@ class RagPipelineService:
account: Account,
datasource_type: str,
is_published: bool,
credential_id: Optional[str] = None,
credential_id: str | None = None,
) -> Mapping[str, Any]:
"""
Run published workflow datasource
@ -876,7 +874,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED.value:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter(Document.id == document_id.value).first()
document = db.session.query(Document).where(Document.id == document_id.value).first()
if document:
document.indexing_status = "error"
document.error = error
@ -887,7 +885,7 @@ class RagPipelineService:
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
) -> Workflow | None:
"""
Update workflow attributes
@ -1057,7 +1055,7 @@ class RagPipelineService:
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
"""
Get workflow run detail
@ -1113,12 +1111,12 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
@ -1142,7 +1140,7 @@ class RagPipelineService:
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
)
@ -1278,7 +1276,7 @@ class RagPipelineService:
# Query active recommended plugins
pipeline_recommended_plugins = (
db.session.query(PipelineRecommendedPlugin)
.filter(PipelineRecommendedPlugin.active == True)
.where(PipelineRecommendedPlugin.active == True)
.order_by(PipelineRecommendedPlugin.position.asc())
.all()
)
@ -1329,12 +1327,12 @@ class RagPipelineService:
"""
document_pipeline_excution_log = (
db.session.query(DocumentPipelineExecutionLog)
.filter(DocumentPipelineExecutionLog.document_id == document.id)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
)
if not document_pipeline_excution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config

View File

@ -6,7 +6,7 @@ import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import Optional, cast
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@ -66,11 +66,11 @@ class ImportStatus(StrEnum):
class RagPipelineImportInfo(BaseModel):
id: str
status: ImportStatus
pipeline_id: Optional[str] = None
pipeline_id: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
dataset_id: Optional[str] = None
dataset_id: str | None = None
class CheckDependenciesResult(BaseModel):
@ -121,12 +121,12 @@ class RagPipelineDslService:
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
pipeline_id: Optional[str] = None,
dataset: Optional[Dataset] = None,
dataset_name: Optional[str] = None,
icon_info: Optional[IconInfo] = None,
yaml_content: str | None = None,
yaml_url: str | None = None,
pipeline_id: str | None = None,
dataset: Dataset | None = None,
dataset_name: str | None = None,
icon_info: IconInfo | None = None,
) -> RagPipelineImportInfo:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
@ -530,10 +530,10 @@ class RagPipelineDslService:
def _create_or_update_pipeline(
self,
*,
pipeline: Optional[Pipeline],
pipeline: Pipeline | None,
data: dict,
account: Account,
dependencies: Optional[list[PluginDependency]] = None,
dependencies: list[PluginDependency] | None = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
if not account.current_tenant_id:

View File

@ -1,7 +1,6 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Optional
from uuid import uuid4
import yaml
@ -21,7 +20,7 @@ from services.plugin.plugin_service import PluginService
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
@ -90,7 +89,7 @@ class RagPipelineTransformService:
"status": "success",
}
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
match datasource_type:
@ -152,7 +151,7 @@ class RagPipelineTransformService:
return node
def _deal_knowledge_index(
self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
@ -289,7 +288,7 @@ class RagPipelineTransformService:
jina_node_id = "1752491761974"
firecrawl_node_id = "1752565402678"
documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all()
documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
for document in documents:
data_source_info_dict = document.data_source_info_dict
@ -299,7 +298,7 @@ class RagPipelineTransformService:
document.data_source_type = "local_file"
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
data_source_info = json.dumps(
{

View File

@ -1,6 +1,5 @@
import logging
import time
from typing import Optional
import click
from celery import shared_task
@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]):
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
"""
Clean document when document deleted.
:param document_ids: document ids

View File

@ -44,7 +44,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@ -76,12 +76,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
@ -100,7 +100,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@ -148,12 +148,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()

View File

@ -104,20 +104,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
with Session(db.engine, expire_on_commit=False) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

View File

@ -125,20 +125,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

View File

@ -38,7 +38,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant