mirror of https://github.com/langgenius/dify.git
chore(api): apply autofix manually
This commit is contained in:
parent
73d4bb596a
commit
eefcd3ecc4
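The diff below applies two mechanical cleanups across the API code: typing.Optional[X] annotations are rewritten as PEP 604 unions (X | None), and legacy SQLAlchemy Query.filter(...) calls are replaced with the equivalent 2.0-style Query.where(...). A minimal before/after sketch of both patterns follows; the Item model and get_item helper are illustrative only and are not part of this commit.

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


class Base(DeclarativeBase):
    pass


class Item(Base):  # hypothetical model used only to illustrate the two fixes
    __tablename__ = "items"
    id: Mapped[str] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column()


# Before: Optional[...] return type and Query.filter()
# def get_item(session: Session, item_id: str) -> Optional[Item]:
#     return session.query(Item).filter(Item.id == item_id).first()


# After: PEP 604 union and Query.where(), the same shape of change applied throughout this commit
def get_item(session: Session, item_id: str) -> Item | None:
    return session.query(Item).where(Item.id == item_id).first()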
@@ -104,7 +104,7 @@ class CustomizedPipelineTemplateApi(Resource):
     def post(self, template_id: str):
         with Session(db.engine) as session:
             template = (
-                session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
+                session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
             )
             if not template:
                 raise ValueError("Customized pipeline template not found.")
@@ -1,6 +1,5 @@
 from collections.abc import Callable
 from functools import wraps
-from typing import Optional
 
 from controllers.console.datasets.error import PipelineNotFoundError
 from extensions.ext_database import db

@@ -10,7 +9,7 @@ from models.dataset import Pipeline
 
 
 def get_rag_pipeline(
-    view: Optional[Callable] = None,
+    view: Callable | None = None,
 ):
     def decorator(view_func):
         @wraps(view_func)
@@ -1,6 +1,6 @@
 from collections.abc import Sequence
 from enum import StrEnum, auto
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 
 from pydantic import BaseModel, Field, field_validator
 

@@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
     hide: bool = False
     max_length: int | None = None
     options: Sequence[str] = Field(default_factory=list)
-    allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
-    allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
-    allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
+    allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
+    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
+    allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
 
     @field_validator("description", mode="before")
     @classmethod

@@ -134,8 +134,8 @@ class RagPipelineVariableEntity(VariableEntity):
     Rag Pipeline Variable Entity.
     """
 
-    tooltips: Optional[str] = None
-    placeholder: Optional[str] = None
+    tooltips: str | None = None
+    placeholder: str | None = None
     belong_to_node_id: str
 
 

@@ -298,7 +298,7 @@ class AppConfig(BaseModel):
     tenant_id: str
     app_id: str
     app_mode: AppMode
-    additional_features: Optional[AppAdditionalFeatures] = None
+    additional_features: AppAdditionalFeatures | None = None
     variables: list[VariableEntity] = []
     sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
 
@@ -7,7 +7,7 @@ import threading
 import time
 import uuid
 from collections.abc import Generator, Mapping
-from typing import Any, Literal, Optional, Union, cast, overload
+from typing import Any, Literal, Union, cast, overload
 
 from flask import Flask, current_app
 from pydantic import ValidationError

@@ -69,7 +69,7 @@ class PipelineGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[True],
         call_depth: int,
-        workflow_thread_pool_id: Optional[str],
+        workflow_thread_pool_id: str | None,
         is_retry: bool = False,
     ) -> Generator[Mapping | str, None, None]: ...
 

@@ -84,7 +84,7 @@ class PipelineGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: Literal[False],
         call_depth: int,
-        workflow_thread_pool_id: Optional[str],
+        workflow_thread_pool_id: str | None,
         is_retry: bool = False,
     ) -> Mapping[str, Any]: ...
 

@@ -99,7 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool,
         call_depth: int,
-        workflow_thread_pool_id: Optional[str],
+        workflow_thread_pool_id: str | None,
         is_retry: bool = False,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
 

@@ -113,7 +113,7 @@ class PipelineGenerator(BaseAppGenerator):
         invoke_from: InvokeFrom,
         streaming: bool = True,
         call_depth: int = 0,
-        workflow_thread_pool_id: Optional[str] = None,
+        workflow_thread_pool_id: str | None = None,
         is_retry: bool = False,
     ) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
         # Add null check for dataset

@@ -314,7 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
         workflow_node_execution_repository: WorkflowNodeExecutionRepository,
         streaming: bool = True,
         variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
-        workflow_thread_pool_id: Optional[str] = None,
+        workflow_thread_pool_id: str | None = None,
     ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
         """
         Generate App response.

@@ -331,7 +331,7 @@ class PipelineGenerator(BaseAppGenerator):
         """
         with preserve_flask_contexts(flask_app, context_vars=context):
             # init queue manager
-            workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
+            workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
             if not workflow:
                 raise ValueError(f"Workflow not found: {workflow_id}")
             queue_manager = PipelineQueueManager(

@@ -568,7 +568,7 @@ class PipelineGenerator(BaseAppGenerator):
         queue_manager: AppQueueManager,
         context: contextvars.Context,
         variable_loader: VariableLoader,
-        workflow_thread_pool_id: Optional[str] = None,
+        workflow_thread_pool_id: str | None = None,
     ) -> None:
         """
         Generate worker in a new thread.

@@ -801,11 +801,11 @@ class PipelineGenerator(BaseAppGenerator):
         self,
         datasource_runtime: OnlineDriveDatasourcePlugin,
         prefix: str,
-        bucket: Optional[str],
+        bucket: str | None,
         user_id: str,
         all_files: list,
         datasource_info: Mapping[str, Any],
-        next_page_parameters: Optional[dict] = None,
+        next_page_parameters: dict | None = None,
     ):
         """
         Get files in a folder.
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Optional, cast
+from typing import cast
 
 from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig

@@ -40,7 +40,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
         variable_loader: VariableLoader,
         workflow: Workflow,
         system_user_id: str,
-        workflow_thread_pool_id: Optional[str] = None,
+        workflow_thread_pool_id: str | None = None,
     ) -> None:
         """
         :param application_generate_entity: application generate entity

@@ -69,13 +69,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
 
         user_id = None
         if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
-            end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
+            end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
             if end_user:
                 user_id = end_user.session_id
         else:
             user_id = self.application_generate_entity.user_id
 
-        pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
+        pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
         if not pipeline:
             raise ValueError("Pipeline not found")
 

@@ -188,7 +188,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
             )
             self._handle_event(workflow_entry, event)
 
-    def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
+    def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
         """
         Get workflow
         """

@@ -205,7 +205,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
         return workflow
 
     def _init_rag_pipeline_graph(
-        self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
+        self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
     ) -> Graph:
         """
         Init pipeline graph
@@ -242,7 +242,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
         node_id: str
         inputs: dict
 
-    single_loop_run: Optional[SingleLoopRunEntity] = None
+    single_loop_run: SingleLoopRunEntity | None = None
 
 
 class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):

@@ -256,9 +256,9 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
     datasource_info: Mapping[str, Any]
     dataset_id: str
     batch: str
-    document_id: Optional[str] = None
-    original_document_id: Optional[str] = None
-    start_node_id: Optional[str] = None
+    document_id: str | None = None
+    original_document_id: str | None = None
+    start_node_id: str | None = None
 
 
 # Import TraceQueueManager at runtime to resolve forward references
@@ -1,6 +1,6 @@
 from collections.abc import Mapping, Sequence
 from enum import StrEnum
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel, ConfigDict, Field
 

@@ -252,8 +252,8 @@ class NodeStartStreamResponse(StreamResponse):
         node_type: str
         title: str
         index: int
-        predecessor_node_id: Optional[str] = None
-        inputs: Optional[Mapping[str, Any]] = None
+        predecessor_node_id: str | None = None
+        inputs: Mapping[str, Any] | None = None
         inputs_truncated: bool = False
         created_at: int
         extras: dict[str, object] = Field(default_factory=dict)

@@ -310,12 +310,12 @@ class NodeFinishStreamResponse(StreamResponse):
         node_type: str
         title: str
         index: int
-        predecessor_node_id: Optional[str] = None
-        inputs: Optional[Mapping[str, Any]] = None
+        predecessor_node_id: str | None = None
+        inputs: Mapping[str, Any] | None = None
         inputs_truncated: bool = False
-        process_data: Optional[Mapping[str, Any]] = None
+        process_data: Mapping[str, Any] | None = None
         process_data_truncated: bool = False
-        outputs: Optional[Mapping[str, Any]] = None
+        outputs: Mapping[str, Any] | None = None
         outputs_truncated: bool = True
         status: str
         error: str | None = None

@@ -382,12 +382,12 @@ class NodeRetryStreamResponse(StreamResponse):
         node_type: str
         title: str
         index: int
-        predecessor_node_id: Optional[str] = None
-        inputs: Optional[Mapping[str, Any]] = None
+        predecessor_node_id: str | None = None
+        inputs: Mapping[str, Any] | None = None
         inputs_truncated: bool = False
-        process_data: Optional[Mapping[str, Any]] = None
+        process_data: Mapping[str, Any] | None = None
         process_data_truncated: bool = False
-        outputs: Optional[Mapping[str, Any]] = None
+        outputs: Mapping[str, Any] | None = None
         outputs_truncated: bool = False
         status: str
         error: str | None = None

@@ -503,11 +503,11 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
         node_id: str
         node_type: str
         title: str
-        outputs: Optional[Mapping] = None
+        outputs: Mapping | None = None
         outputs_truncated: bool = False
         created_at: int
-        extras: Optional[dict] = None
-        inputs: Optional[Mapping] = None
+        extras: dict | None = None
+        inputs: Mapping | None = None
         inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
         error: str | None = None

@@ -541,8 +541,8 @@ class LoopNodeStartStreamResponse(StreamResponse):
         metadata: Mapping = {}
         inputs: Mapping = {}
         inputs_truncated: bool = False
-        parallel_id: Optional[str] = None
-        parallel_start_node_id: Optional[str] = None
+        parallel_id: str | None = None
+        parallel_start_node_id: str | None = None
 
     event: StreamEvent = StreamEvent.LOOP_STARTED
     workflow_run_id: str

@@ -590,11 +590,11 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
         node_id: str
         node_type: str
         title: str
-        outputs: Optional[Mapping] = None
+        outputs: Mapping | None = None
         outputs_truncated: bool = False
         created_at: int
-        extras: Optional[dict] = None
-        inputs: Optional[Mapping] = None
+        extras: dict | None = None
+        inputs: Mapping | None = None
         inputs_truncated: bool = False
         status: WorkflowNodeExecutionStatus
         error: str | None = None
@@ -17,9 +17,9 @@ class DatasourceRuntime(BaseModel):
     """
 
     tenant_id: str
-    datasource_id: Optional[str] = None
+    datasource_id: str | None = None
     invoke_from: Optional["InvokeFrom"] = None
-    datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
+    datasource_invoke_from: DatasourceInvokeFrom | None = None
     credentials: dict[str, Any] = Field(default_factory=dict)
     runtime_parameters: dict[str, Any] = Field(default_factory=dict)
 
@@ -6,7 +6,7 @@ import os
 import time
 from datetime import datetime
 from mimetypes import guess_extension, guess_type
-from typing import Optional, Union
+from typing import Union
 from uuid import uuid4
 
 import httpx

@@ -62,10 +62,10 @@ class DatasourceFileManager:
         *,
         user_id: str,
         tenant_id: str,
-        conversation_id: Optional[str],
+        conversation_id: str | None,
         file_binary: bytes,
         mimetype: str,
-        filename: Optional[str] = None,
+        filename: str | None = None,
     ) -> UploadFile:
         extension = guess_extension(mimetype) or ".bin"
         unique_name = uuid4().hex

@@ -106,7 +106,7 @@ class DatasourceFileManager:
         user_id: str,
         tenant_id: str,
         file_url: str,
-        conversation_id: Optional[str] = None,
+        conversation_id: str | None = None,
     ) -> ToolFile:
         # try to download image
         try:

@@ -153,10 +153,7 @@ class DatasourceFileManager:
         :return: the binary of the file, mime type
         """
         upload_file: UploadFile | None = (
-            db.session.query(UploadFile)
-            .filter(
-                UploadFile.id == id,
-            )
+            db.session.query(UploadFile).where(UploadFile.id == id)
             .first()
         )
 

@@ -177,10 +174,7 @@ class DatasourceFileManager:
         :return: the binary of the file, mime type
         """
         message_file: MessageFile | None = (
-            db.session.query(MessageFile)
-            .filter(
-                MessageFile.id == id,
-            )
+            db.session.query(MessageFile).where(MessageFile.id == id)
             .first()
         )
 

@@ -197,10 +191,7 @@ class DatasourceFileManager:
         tool_file_id = None
 
         tool_file: ToolFile | None = (
-            db.session.query(ToolFile)
-            .filter(
-                ToolFile.id == tool_file_id,
-            )
+            db.session.query(ToolFile).where(ToolFile.id == tool_file_id)
             .first()
         )
 

@@ -221,10 +212,7 @@ class DatasourceFileManager:
         :return: the binary of the file, mime type
         """
         upload_file: UploadFile | None = (
-            db.session.query(UploadFile)
-            .filter(
-                UploadFile.id == upload_file_id,
-            )
+            db.session.query(UploadFile).where(UploadFile.id == upload_file_id)
             .first()
         )
 
@@ -12,9 +12,9 @@ class DatasourceApiEntity(BaseModel):
     name: str  # identifier
     label: I18nObject  # label
     description: I18nObject
-    parameters: Optional[list[DatasourceParameter]] = None
+    parameters: list[DatasourceParameter] | None = None
     labels: list[str] = Field(default_factory=list)
-    output_schema: Optional[dict] = None
+    output_schema: dict | None = None
 
 
 ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]

@@ -28,12 +28,12 @@ class DatasourceProviderApiEntity(BaseModel):
     icon: str | dict
     label: I18nObject  # label
     type: str
-    masked_credentials: Optional[dict] = None
-    original_credentials: Optional[dict] = None
+    masked_credentials: dict | None = None
+    original_credentials: dict | None = None
     is_team_authorization: bool = False
     allow_delete: bool = True
-    plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
-    plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
+    plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
+    plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource")
     datasources: list[DatasourceApiEntity] = Field(default_factory=list)
     labels: list[str] = Field(default_factory=list)
 
@@ -1,4 +1,3 @@
-from typing import Optional
 
 from pydantic import BaseModel, Field
 

@@ -9,9 +8,9 @@ class I18nObject(BaseModel):
     """
 
     en_US: str
-    zh_Hans: Optional[str] = Field(default=None)
-    pt_BR: Optional[str] = Field(default=None)
-    ja_JP: Optional[str] = Field(default=None)
+    zh_Hans: str | None = Field(default=None)
+    pt_BR: str | None = Field(default=None)
+    ja_JP: str | None = Field(default=None)
 
     def __init__(self, **data):
         super().__init__(**data)
@@ -1,6 +1,6 @@
 import enum
 from enum import Enum
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel, Field, ValidationInfo, field_validator
 from yarl import URL

@@ -80,7 +80,7 @@ class DatasourceParameter(PluginParameter):
         name: str,
         typ: DatasourceParameterType,
         required: bool,
-        options: Optional[list[str]] = None,
+        options: list[str] | None = None,
     ) -> "DatasourceParameter":
         """
         get a simple datasource parameter

@@ -120,14 +120,14 @@ class DatasourceIdentity(BaseModel):
     name: str = Field(..., description="The name of the datasource")
     label: I18nObject = Field(..., description="The label of the datasource")
     provider: str = Field(..., description="The provider of the datasource")
-    icon: Optional[str] = None
+    icon: str | None = None
 
 
 class DatasourceEntity(BaseModel):
     identity: DatasourceIdentity
     parameters: list[DatasourceParameter] = Field(default_factory=list)
     description: I18nObject = Field(..., description="The label of the datasource")
-    output_schema: Optional[dict] = None
+    output_schema: dict | None = None
 
     @field_validator("parameters", mode="before")
     @classmethod

@@ -141,7 +141,7 @@ class DatasourceProviderIdentity(BaseModel):
     description: I18nObject = Field(..., description="The description of the tool")
     icon: str = Field(..., description="The icon of the tool")
     label: I18nObject = Field(..., description="The label of the tool")
-    tags: Optional[list[ToolLabelEnum]] = Field(
+    tags: list[ToolLabelEnum] | None = Field(
         default=[],
         description="The tags of the tool",
     )

@@ -169,7 +169,7 @@ class DatasourceProviderEntity(BaseModel):
 
     identity: DatasourceProviderIdentity
     credentials_schema: list[ProviderConfig] = Field(default_factory=list)
-    oauth_schema: Optional[OAuthSchema] = None
+    oauth_schema: OAuthSchema | None = None
     provider_type: DatasourceProviderType
 
 

@@ -183,8 +183,8 @@ class DatasourceInvokeMeta(BaseModel):
     """
 
     time_cost: float = Field(..., description="The time cost of the tool invoke")
-    error: Optional[str] = None
-    tool_config: Optional[dict] = None
+    error: str | None = None
+    tool_config: dict | None = None
 
     @classmethod
     def empty(cls) -> "DatasourceInvokeMeta":

@@ -233,10 +233,10 @@ class OnlineDocumentPage(BaseModel):
 
     page_id: str = Field(..., description="The page id")
     page_name: str = Field(..., description="The page title")
-    page_icon: Optional[dict] = Field(None, description="The page icon")
+    page_icon: dict | None = Field(None, description="The page icon")
     type: str = Field(..., description="The type of the page")
     last_edited_time: str = Field(..., description="The last edited time")
-    parent_id: Optional[str] = Field(None, description="The parent page id")
+    parent_id: str | None = Field(None, description="The parent page id")
 
 
 class OnlineDocumentInfo(BaseModel):

@@ -244,9 +244,9 @@ class OnlineDocumentInfo(BaseModel):
     Online document info
     """
 
-    workspace_id: Optional[str] = Field(None, description="The workspace id")
-    workspace_name: Optional[str] = Field(None, description="The workspace name")
-    workspace_icon: Optional[str] = Field(None, description="The workspace icon")
+    workspace_id: str | None = Field(None, description="The workspace id")
+    workspace_name: str | None = Field(None, description="The workspace name")
+    workspace_icon: str | None = Field(None, description="The workspace icon")
     total: int = Field(..., description="The total number of documents")
     pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
 

@@ -307,10 +307,10 @@ class WebSiteInfo(BaseModel):
     Website info
     """
 
-    status: Optional[str] = Field(..., description="crawl job status")
-    web_info_list: Optional[list[WebSiteInfoDetail]] = []
-    total: Optional[int] = Field(default=0, description="The total number of websites")
-    completed: Optional[int] = Field(default=0, description="The number of completed websites")
+    status: str | None = Field(..., description="crawl job status")
+    web_info_list: list[WebSiteInfoDetail] | None = []
+    total: int | None = Field(default=0, description="The total number of websites")
+    completed: int | None = Field(default=0, description="The number of completed websites")
 
 
 class WebsiteCrawlMessage(BaseModel):

@@ -346,10 +346,10 @@ class OnlineDriveFileBucket(BaseModel):
     Online drive file bucket
     """
 
-    bucket: Optional[str] = Field(None, description="The file bucket")
+    bucket: str | None = Field(None, description="The file bucket")
     files: list[OnlineDriveFile] = Field(..., description="The file list")
     is_truncated: bool = Field(False, description="Whether the result is truncated")
-    next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
+    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
 
 
 class OnlineDriveBrowseFilesRequest(BaseModel):

@@ -357,10 +357,10 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
     Get online drive file list request
     """
 
-    bucket: Optional[str] = Field(None, description="The file bucket")
+    bucket: str | None = Field(None, description="The file bucket")
     prefix: str = Field(..., description="The parent folder ID")
     max_keys: int = Field(20, description="Page size for pagination")
-    next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
+    next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
 
 
 class OnlineDriveBrowseFilesResponse(BaseModel):

@@ -377,4 +377,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
     """
 
     id: str = Field(..., description="The id of the file")
-    bucket: Optional[str] = Field(None, description="The name of the bucket")
+    bucket: str | None = Field(None, description="The name of the bucket")
@@ -1,7 +1,6 @@
 import logging
 from collections.abc import Generator
 from mimetypes import guess_extension, guess_type
-from typing import Optional
 
 from core.datasource.entities.datasource_entities import DatasourceMessage
 from core.file import File, FileTransferMethod, FileType

@@ -17,7 +16,7 @@ class DatasourceFileMessageTransformer:
         messages: Generator[DatasourceMessage, None, None],
         user_id: str,
         tenant_id: str,
-        conversation_id: Optional[str] = None,
+        conversation_id: str | None = None,
     ) -> Generator[DatasourceMessage, None, None]:
         """
         Transform datasource message and handle file download

@@ -121,5 +120,5 @@ class DatasourceFileMessageTransformer:
             yield message
 
     @classmethod
-    def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
+    def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str:
         return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
@@ -3,7 +3,6 @@ import uuid
 from json import dumps as json_dumps
 from json import loads as json_loads
 from json.decoder import JSONDecodeError
-from typing import Optional
 
 from flask import request
 from requests import get

@@ -169,9 +168,9 @@ class ApiBasedToolSchemaParser:
         return bundles
 
     @staticmethod
-    def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
+    def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
         parameter = parameter or {}
-        typ: Optional[str] = None
+        typ: str | None = None
         if parameter.get("format") == "binary":
             return ToolParameter.ToolParameterType.FILE
 
@@ -1,4 +1,3 @@
-from typing import Optional
 
 from pydantic import BaseModel
 

@@ -16,7 +15,7 @@ class QAPreviewDetail(BaseModel):
 class IndexingEstimate(BaseModel):
     total_segments: int
     preview: list[PreviewDetail]
-    qa_preview: Optional[list[QAPreviewDetail]] = None
+    qa_preview: list[QAPreviewDetail] | None = None
 
 
 class PipelineDataset(BaseModel):

@@ -30,10 +29,10 @@ class PipelineDocument(BaseModel):
     id: str
     position: int
     data_source_type: str
-    data_source_info: Optional[dict] = None
+    data_source_info: dict | None = None
     name: str
     indexing_status: str
-    error: Optional[str] = None
+    error: str | None = None
     enabled: bool
 
 
@@ -1,7 +1,7 @@
 import datetime
 from collections.abc import Mapping
 from enum import StrEnum, auto
-from typing import Any, Optional
+from typing import Any
 
 from packaging.version import InvalidVersion, Version
 from pydantic import BaseModel, Field, field_validator, model_validator

@@ -67,10 +67,10 @@ class PluginCategory(StrEnum):
 
 class PluginDeclaration(BaseModel):
     class Plugins(BaseModel):
-        tools: Optional[list[str]] = Field(default_factory=list[str])
-        models: Optional[list[str]] = Field(default_factory=list[str])
-        endpoints: Optional[list[str]] = Field(default_factory=list[str])
-        datasources: Optional[list[str]] = Field(default_factory=list[str])
+        tools: list[str] | None = Field(default_factory=list[str])
+        models: list[str] | None = Field(default_factory=list[str])
+        endpoints: list[str] | None = Field(default_factory=list[str])
+        datasources: list[str] | None = Field(default_factory=list[str])
 
     class Meta(BaseModel):
         minimum_dify_version: str | None = Field(default=None)

@@ -101,11 +101,11 @@ class PluginDeclaration(BaseModel):
     tags: list[str] = Field(default_factory=list)
     repo: str | None = Field(default=None)
     verified: bool = Field(default=False)
-    tool: Optional[ToolProviderEntity] = None
-    model: Optional[ProviderEntity] = None
-    endpoint: Optional[EndpointProviderDeclaration] = None
-    agent_strategy: Optional[AgentStrategyProviderEntity] = None
-    datasource: Optional[DatasourceProviderEntity] = None
+    tool: ToolProviderEntity | None = None
+    model: ProviderEntity | None = None
+    endpoint: EndpointProviderDeclaration | None = None
+    agent_strategy: AgentStrategyProviderEntity | None = None
+    datasource: DatasourceProviderEntity | None = None
     meta: Meta
 
     @field_validator("version")
@@ -1,6 +1,6 @@
 from collections.abc import Mapping
 from enum import Enum
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel, Field
 

@@ -27,12 +27,12 @@ class DatasourceErrorEvent(BaseDatasourceEvent):
 class DatasourceCompletedEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.COMPLETED.value
     data: Mapping[str, Any] | list = Field(..., description="result")
-    total: Optional[int] = Field(default=0, description="total")
-    completed: Optional[int] = Field(default=0, description="completed")
-    time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
+    total: int | None = Field(default=0, description="total")
+    completed: int | None = Field(default=0, description="completed")
+    time_consuming: float | None = Field(default=0.0, description="time consuming")
 
 
 class DatasourceProcessingEvent(BaseDatasourceEvent):
     event: str = DatasourceStreamEvent.PROCESSING.value
-    total: Optional[int] = Field(..., description="total")
-    completed: Optional[int] = Field(..., description="completed")
+    total: int | None = Field(..., description="total")
+    completed: int | None = Field(..., description="completed")
@@ -1,4 +1,3 @@
-from typing import Optional
 
 from pydantic import BaseModel, ConfigDict
 

@@ -11,7 +10,7 @@ class NotionInfo(BaseModel):
     Notion import info.
     """
 
-    credential_id: Optional[str] = None
+    credential_id: str | None = None
     notion_workspace_id: str
     notion_obj_id: str
     notion_page_type: str
@@ -1,7 +1,7 @@
 import json
 import logging
 import operator
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import requests
 

@@ -35,9 +35,9 @@ class NotionExtractor(BaseExtractor):
         notion_obj_id: str,
         notion_page_type: str,
         tenant_id: str,
-        document_model: Optional[DocumentModel] = None,
-        notion_access_token: Optional[str] = None,
-        credential_id: Optional[str] = None,
+        document_model: DocumentModel | None = None,
+        notion_access_token: str | None = None,
+        credential_id: str | None = None,
     ):
         self._notion_access_token = None
         self._document_model = document_model

@@ -369,7 +369,7 @@ class NotionExtractor(BaseExtractor):
         return cast(str, data["last_edited_time"])
 
     @classmethod
-    def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
+    def _get_access_token(cls, tenant_id: str, credential_id: str | None) -> str:
         # get credential from tenant_id and credential_id
         if not credential_id:
             raise Exception(f"No credential id found for tenant {tenant_id}")
@@ -7,7 +7,7 @@ import json
 import logging
 from collections.abc import Callable, Mapping, Sequence
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Optional, TypeVar, Union
+from typing import Any, TypeVar, Union
 
 import psycopg2.errors
 from sqlalchemy import UnaryExpression, asc, desc, select

@@ -530,7 +530,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
     def get_by_workflow_run(
         self,
         workflow_run_id: str,
-        order_config: Optional[OrderConfig] = None,
+        order_config: OrderConfig | None = None,
         triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
     ) -> Sequence[WorkflowNodeExecution]:
         """
@@ -85,7 +85,7 @@ class SchemaRegistry:
             except (OSError, json.JSONDecodeError) as e:
                 print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
 
-    def get_schema(self, uri: str) -> Optional[Any]:
+    def get_schema(self, uri: str) -> Any | None:
         """Retrieves a schema by URI with version support"""
         version, schema_name = self._parse_uri(uri)
         if not version or not schema_name:
@@ -3,7 +3,7 @@ import re
 import threading
 from collections import deque
 from dataclasses import dataclass
-from typing import Any, Optional, Union
+from typing import Any, Union
 
 from core.schemas.registry import SchemaRegistry
 

@@ -53,8 +53,8 @@ class QueueItem:
     """Represents an item in the BFS queue"""
 
     current: Any
-    parent: Optional[Any]
-    key: Optional[Union[str, int]]
+    parent: Any | None
+    key: Union[str, int] | None
     depth: int
     ref_path: set[str]
 

@@ -65,7 +65,7 @@ class SchemaResolver:
     _cache: dict[str, SchemaDict] = {}
     _cache_lock = threading.Lock()
 
-    def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
+    def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10):
         """
         Initialize the schema resolver
 

@@ -202,7 +202,7 @@ class SchemaResolver:
                 )
             )
 
-    def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
+    def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None:
         """Get resolved schema from cache or registry"""
         # Check cache first
         with self._cache_lock:

@@ -223,7 +223,7 @@
 
 
 def resolve_dify_schema_refs(
-    schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
+    schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30
 ) -> SchemaType:
     """
     Resolve $ref references in Dify schema to actual schema content
@@ -1,5 +1,5 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any
 
 from core.schemas.registry import SchemaRegistry
 

@@ -7,7 +7,7 @@ from core.schemas.registry import SchemaRegistry
 class SchemaManager:
     """Schema manager provides high-level schema operations"""
 
-    def __init__(self, registry: Optional[SchemaRegistry] = None):
+    def __init__(self, registry: SchemaRegistry | None = None):
         self.registry = registry or SchemaRegistry.default_registry()
 
     def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:

@@ -22,7 +22,7 @@ class SchemaManager:
         """
         return self.registry.get_all_schemas_for_version(version)
 
-    def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
+    def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Mapping[str, Any] | None:
         """
         Get a specific schema by name
 
@@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc.
 
 from collections.abc import Mapping
 from datetime import datetime
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel, Field, PrivateAttr
 

@@ -53,9 +53,9 @@ class WorkflowNodeExecution(BaseModel):
 
     # Execution data
     # The `inputs` and `outputs` fields hold the full content
-    inputs: Optional[Mapping[str, Any]] = None  # Input variables used by this node
-    process_data: Optional[Mapping[str, Any]] = None  # Intermediate processing data
-    outputs: Optional[Mapping[str, Any]] = None  # Output variables produced by this node
+    inputs: Mapping[str, Any] | None = None  # Input variables used by this node
+    process_data: Mapping[str, Any] | None = None  # Intermediate processing data
+    outputs: Mapping[str, Any] | None = None  # Output variables produced by this node
 
     # Execution state
     status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING  # Current execution status
@@ -1,5 +1,5 @@
 from collections.abc import Generator, Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 from sqlalchemy import select
 from sqlalchemy.orm import Session

@@ -50,7 +50,7 @@ class DatasourceNode(Node):
     def init_node_data(self, data: Mapping[str, Any]) -> None:
         self._node_data = DatasourceNodeData.model_validate(data)
 
-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
 
     def _get_retry_config(self) -> RetryConfig:

@@ -59,7 +59,7 @@ class DatasourceNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title
 
-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc
 
     def _get_default_value_dict(self) -> dict[str, Any]:

@@ -179,7 +179,7 @@ class DatasourceNode(Node):
             related_id = datasource_info.get("related_id")
             if not related_id:
                 raise DatasourceNodeError("File is not exist")
-            upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
+            upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
             if not upload_file:
                 raise ValueError("Invalid upload file Info")
 
@@ -1,4 +1,4 @@
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Union
 
 from pydantic import BaseModel, field_validator
 from pydantic_core.core_schema import ValidationInfo

@@ -10,7 +10,7 @@ class DatasourceEntity(BaseModel):
     plugin_id: str
     provider_name: str  # redundancy
     provider_type: str
-    datasource_name: Optional[str] = "local_file"
+    datasource_name: str | None = "local_file"
     datasource_configurations: dict[str, Any] | None = None
     plugin_unique_identifier: str | None = None  # redundancy
 

@@ -19,7 +19,7 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
     class DatasourceInput(BaseModel):
         # TODO: check this type
         value: Union[Any, list[str]]
-        type: Optional[Literal["mixed", "variable", "constant"]] = None
+        type: Literal["mixed", "variable", "constant"] | None = None
 
         @field_validator("type", mode="before")
         @classmethod
@@ -1,4 +1,4 @@
-from typing import Literal, Optional, Union
+from typing import Literal, Union
 
 from pydantic import BaseModel
 

@@ -65,12 +65,12 @@ class RetrievalSetting(BaseModel):
 
     search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
     top_k: int
-    score_threshold: Optional[float] = 0.5
+    score_threshold: float | None = 0.5
     score_threshold_enabled: bool = False
     reranking_mode: str = "reranking_model"
     reranking_enable: bool = True
-    reranking_model: Optional[RerankingModelConfig] = None
-    weights: Optional[WeightedScoreConfig] = None
+    reranking_model: RerankingModelConfig | None = None
+    weights: WeightedScoreConfig | None = None
 
 
 class IndexMethod(BaseModel):

@@ -107,10 +107,10 @@ class OnlineDocumentInfo(BaseModel):
     """
 
     provider: str
-    workspace_id: Optional[str] = None
+    workspace_id: str | None = None
     page_id: str
     page_type: str
-    icon: Optional[OnlineDocumentIcon] = None
+    icon: OnlineDocumentIcon | None = None
 
 
 class WebsiteInfo(BaseModel):
@@ -2,7 +2,7 @@ import datetime
 import logging
 import time
 from collections.abc import Mapping
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 from sqlalchemy import func, select
 

@@ -43,7 +43,7 @@ class KnowledgeIndexNode(Node):
     def init_node_data(self, data: Mapping[str, Any]) -> None:
         self._node_data = KnowledgeIndexNodeData.model_validate(data)
 
-    def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+    def _get_error_strategy(self) -> ErrorStrategy | None:
         return self._node_data.error_strategy
 
     def _get_retry_config(self) -> RetryConfig:

@@ -52,7 +52,7 @@ class KnowledgeIndexNode(Node):
     def _get_title(self) -> str:
         return self._node_data.title
 
-    def _get_description(self) -> Optional[str]:
+    def _get_description(self) -> str | None:
         return self._node_data.desc
 
     def _get_default_value_dict(self) -> dict[str, Any]:
@@ -10,7 +10,7 @@ import re
 import time
 from datetime import datetime
 from json import JSONDecodeError
-from typing import Any, Optional, cast
+from typing import Any, cast
 
 import sqlalchemy as sa
 from sqlalchemy import DateTime, String, func, select

@@ -76,7 +76,7 @@ class Dataset(Base):
 
     @property
     def total_documents(self):
-        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
+        return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
 
     @property
     def total_available_documents(self):

@@ -173,10 +173,10 @@ class Dataset(Base):
         )
 
     @property
-    def doc_form(self) -> Optional[str]:
+    def doc_form(self) -> str | None:
         if self.chunk_structure:
             return self.chunk_structure
-        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
+        document = db.session.query(Document).where(Document.dataset_id == self.id).first()
         if document:
             return document.doc_form
         return None

@@ -234,7 +234,7 @@ class Dataset(Base):
     @property
     def is_published(self):
         if self.pipeline_id:
-            pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
+            pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
             if pipeline:
                 return pipeline.is_published
         return False

@@ -1244,7 +1244,7 @@ class PipelineBuiltInTemplate(Base):  # type: ignore[name-defined]
 
     @property
     def created_user_name(self):
-        account = db.session.query(Account).filter(Account.id == self.created_by).first()
+        account = db.session.query(Account).where(Account.id == self.created_by).first()
         if account:
             return account.name
         return ""

@@ -1274,7 +1274,7 @@ class PipelineCustomizedTemplate(Base):  # type: ignore[name-defined]
 
     @property
     def created_user_name(self):
-        account = db.session.query(Account).filter(Account.id == self.created_by).first()
+        account = db.session.query(Account).where(Account.id == self.created_by).first()
         if account:
             return account.name
         return ""

@@ -1297,7 +1297,7 @@ class Pipeline(Base):  # type: ignore[name-defined]
     updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
 
     def retrieve_dataset(self, session: Session):
-        return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
+        return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
 
 
 class DocumentPipelineExecutionLog(Base):

@@ -1546,7 +1546,7 @@ class WorkflowDraftVariableFile(Base):
         comment="Size of the original variable content in bytes",
     )
 
-    length: Mapped[Optional[int]] = mapped_column(
+    length: Mapped[int | None] = mapped_column(
         sa.Integer,
         nullable=True,
         comment=(
@@ -7,7 +7,7 @@ import time
 import uuid
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Literal, Optional
+from typing import Any, Literal
 
 import sqlalchemy as sa
 from sqlalchemy import exists, func, select

@@ -315,8 +315,8 @@ class DatasetService:
         return dataset
 
     @staticmethod
-    def get_dataset(dataset_id) -> Optional[Dataset]:
-        dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
+    def get_dataset(dataset_id) -> Dataset | None:
+        dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
         return dataset
 
     @staticmethod
@@ -1,13 +1,13 @@
-from typing import Literal, Optional
+from typing import Literal
 
 from pydantic import BaseModel, field_validator
 
 
 class IconInfo(BaseModel):
     icon: str
-    icon_background: Optional[str] = None
-    icon_type: Optional[str] = None
-    icon_url: Optional[str] = None
+    icon_background: str | None = None
+    icon_type: str | None = None
+    icon_url: str | None = None
 
 
 class PipelineTemplateInfoEntity(BaseModel):

@@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel):
     description: str
     icon_info: IconInfo
     permission: str
-    partial_member_list: Optional[list[str]] = None
-    yaml_content: Optional[str] = None
+    partial_member_list: list[str] | None = None
+    yaml_content: str | None = None
 
 
 class RerankingModelConfig(BaseModel):

@@ -30,8 +30,8 @@ class RerankingModelConfig(BaseModel):
     Reranking Model Config.
     """
 
-    reranking_provider_name: Optional[str] = ""
-    reranking_model_name: Optional[str] = ""
+    reranking_provider_name: str | None = ""
+    reranking_model_name: str | None = ""
 
 
 class VectorSetting(BaseModel):

@@ -57,8 +57,8 @@ class WeightedScoreConfig(BaseModel):
     Weighted score Config.
     """
 
-    vector_setting: Optional[VectorSetting]
-    keyword_setting: Optional[KeywordSetting]
+    vector_setting: VectorSetting | None
+    keyword_setting: KeywordSetting | None
 
 
 class EmbeddingSetting(BaseModel):

@@ -85,12 +85,12 @@ class RetrievalSetting(BaseModel):
 
     search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
     top_k: int
-    score_threshold: Optional[float] = 0.5
+    score_threshold: float | None = 0.5
     score_threshold_enabled: bool = False
-    reranking_mode: Optional[str] = "reranking_model"
-    reranking_enable: Optional[bool] = True
-    reranking_model: Optional[RerankingModelConfig] = None
-    weights: Optional[WeightedScoreConfig] = None
+    reranking_mode: str | None = "reranking_model"
+    reranking_enable: bool | None = True
+    reranking_model: RerankingModelConfig | None = None
+    weights: WeightedScoreConfig | None = None
 
 
 class IndexMethod(BaseModel):

@@ -112,7 +112,7 @@ class KnowledgeConfiguration(BaseModel):
     indexing_technique: Literal["high_quality", "economy"]
     embedding_model_provider: str = ""
    embedding_model: str = ""
-    keyword_number: Optional[int] = 10
+    keyword_number: int | None = 10
     retrieval_model: RetrievalSetting
 
     @field_validator("embedding_model_provider", mode="before")
@@ -1,5 +1,5 @@
 from collections.abc import Mapping
-from typing import Any, Optional
+from typing import Any
 
 from pydantic import BaseModel
 

@@ -9,7 +9,7 @@ class DatasourceNodeRunApiEntity(BaseModel):
     node_id: str
     inputs: Mapping[str, Any]
     datasource_type: str
-    credential_id: Optional[str] = None
+    credential_id: str | None = None
     is_published: bool
 
 
@@ -108,7 +108,7 @@ class PipelineGenerateService:
         Update document status to waiting
         :param document_id: document id
         """
-        document = db.session.query(Document).filter(Document.id == document_id).first()
+        document = db.session.query(Document).where(Document.id == document_id).first()
         if document:
             document.indexing_status = "waiting"
             db.session.add(document)
@@ -1,7 +1,6 @@
 import json
 from os import path
 from pathlib import Path
-from typing import Optional
 
 from flask import current_app
 

@@ -14,7 +13,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
     Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
     """
 
-    builtin_data: Optional[dict] = None
+    builtin_data: dict | None = None
 
     def get_type(self) -> str:
         return PipelineTemplateType.BUILTIN

@@ -54,7 +53,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         return builtin_data.get("pipeline_templates", {}).get(language, {})
 
     @classmethod
-    def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
+    def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
         """
         Fetch pipeline template detail from builtin.
         :param template_id: Template ID
@@ -1,4 +1,3 @@
-from typing import Optional
 
 import yaml
 from flask_login import current_user

@@ -56,14 +55,14 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         return {"pipeline_templates": recommended_pipelines_results}
 
     @classmethod
-    def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
+    def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
         """
         Fetch pipeline template detail from db.
         :param template_id: Template ID
         :return:
         """
         pipeline_template = (
-            db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
+            db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
         )
         if not pipeline_template:
             return None
@@ -1,4 +1,3 @@
-from typing import Optional
 
 import yaml
 

@@ -33,7 +32,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         """
 
         pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
-            db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
+            db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
         )
 
         recommended_pipelines_results = []

@@ -53,7 +52,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         return {"pipeline_templates": recommended_pipelines_results}
 
     @classmethod
-    def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
+    def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
         """
         Fetch pipeline template detail from db.
         :param pipeline_id: Pipeline ID

@@ -61,7 +60,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         """
         # is in public recommended list
         pipeline_template = (
-            db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == template_id).first()
+            db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
         )
 
         if not pipeline_template:
@@ -1,5 +1,4 @@
 from abc import ABC, abstractmethod
-from typing import Optional
 
 
 class PipelineTemplateRetrievalBase(ABC):

@@ -10,7 +9,7 @@ class PipelineTemplateRetrievalBase(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]:
+    def get_pipeline_template_detail(self, template_id: str) -> dict | None:
         raise NotImplementedError
 
     @abstractmethod
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional
 
 import requests
 

@@ -36,7 +35,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
         return PipelineTemplateType.REMOTE
 
     @classmethod
-    def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> Optional[dict]:
+    def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
         """
         Fetch pipeline template detail from dify official.
         :param template_id: Pipeline ID
@@ -5,7 +5,7 @@ import threading
 import time
 from collections.abc import Callable, Generator, Mapping, Sequence
 from datetime import UTC, datetime
-from typing import Any, Optional, Union, cast
+from typing import Any, Union, cast
 from uuid import uuid4

 from flask_login import current_user

@@ -112,7 +112,7 @@ class RagPipelineService:
 return result

 @classmethod
-def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
+def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
 """
 Get pipeline template detail.
 :param template_id: template id

@@ -121,12 +121,12 @@ class RagPipelineService:
 if type == "built-in":
 mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
 retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
-built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
+built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
 return built_in_result
 else:
 mode = "customized"
 retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
-customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
+customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
 return customized_result

 @classmethod

@@ -185,7 +185,7 @@ class RagPipelineService:
 db.session.delete(customized_template)
 db.session.commit()

-def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
+def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None:
 """
 Get draft workflow
 """

@@ -203,7 +203,7 @@ class RagPipelineService:
 # return draft workflow
 return workflow

-def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
+def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None:
 """
 Get published workflow
 """
@@ -267,7 +267,7 @@ class RagPipelineService:
 *,
 pipeline: Pipeline,
 graph: dict,
-unique_hash: Optional[str],
+unique_hash: str | None,
 account: Account,
 environment_variables: Sequence[Variable],
 conversation_variables: Sequence[Variable],

@@ -387,9 +387,7 @@ class RagPipelineService:

 return default_block_configs

-def get_default_block_config(
-self, node_type: str, filters: Optional[dict] = None
-) -> Optional[Mapping[str, object]]:
+def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
 """
 Get default config of node.
 :param node_type: node type

@@ -495,7 +493,7 @@ class RagPipelineService:
 account: Account,
 datasource_type: str,
 is_published: bool,
-credential_id: Optional[str] = None,
+credential_id: str | None = None,
 ) -> Generator[Mapping[str, Any], None, None]:
 """
 Run published workflow datasource

@@ -661,7 +659,7 @@ class RagPipelineService:
 account: Account,
 datasource_type: str,
 is_published: bool,
-credential_id: Optional[str] = None,
+credential_id: str | None = None,
 ) -> Mapping[str, Any]:
 """
 Run published workflow datasource

@@ -876,7 +874,7 @@ class RagPipelineService:
 if invoke_from.value == InvokeFrom.PUBLISHED.value:
 document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
 if document_id:
-document = db.session.query(Document).filter(Document.id == document_id.value).first()
+document = db.session.query(Document).where(Document.id == document_id.value).first()
 if document:
 document.indexing_status = "error"
 document.error = error

@@ -887,7 +885,7 @@ class RagPipelineService:

 def update_workflow(
 self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
-) -> Optional[Workflow]:
+) -> Workflow | None:
 """
 Update workflow attributes
@@ -1057,7 +1055,7 @@ class RagPipelineService:

 return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)

-def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
+def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
 """
 Get workflow run detail

@@ -1113,12 +1111,12 @@ class RagPipelineService:
 """
 Publish customized pipeline template
 """
-pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
+pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
 if not pipeline:
 raise ValueError("Pipeline not found")
 if not pipeline.workflow_id:
 raise ValueError("Pipeline workflow not found")
-workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
+workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
 if not workflow:
 raise ValueError("Workflow not found")
 with Session(db.engine) as session:

@@ -1142,7 +1140,7 @@ class RagPipelineService:

 max_position = (
 db.session.query(func.max(PipelineCustomizedTemplate.position))
-.filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
+.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
 .scalar()
 )

@@ -1278,7 +1276,7 @@ class RagPipelineService:
 # Query active recommended plugins
 pipeline_recommended_plugins = (
 db.session.query(PipelineRecommendedPlugin)
-.filter(PipelineRecommendedPlugin.active == True)
+.where(PipelineRecommendedPlugin.active == True)
 .order_by(PipelineRecommendedPlugin.position.asc())
 .all()
 )

@@ -1329,12 +1327,12 @@ class RagPipelineService:
 """
 document_pipeline_excution_log = (
 db.session.query(DocumentPipelineExecutionLog)
-.filter(DocumentPipelineExecutionLog.document_id == document.id)
+.where(DocumentPipelineExecutionLog.document_id == document.id)
 .first()
 )
 if not document_pipeline_excution_log:
 raise ValueError("Document pipeline execution log not found")
-pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
+pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
 if not pipeline:
 raise ValueError("Pipeline not found")
 # convert to app config
@@ -6,7 +6,7 @@ import uuid
 from collections.abc import Mapping
 from datetime import UTC, datetime
 from enum import StrEnum
-from typing import Optional, cast
+from typing import cast
 from urllib.parse import urlparse
 from uuid import uuid4

@@ -66,11 +66,11 @@ class ImportStatus(StrEnum):
 class RagPipelineImportInfo(BaseModel):
 id: str
 status: ImportStatus
-pipeline_id: Optional[str] = None
+pipeline_id: str | None = None
 current_dsl_version: str = CURRENT_DSL_VERSION
 imported_dsl_version: str = ""
 error: str = ""
-dataset_id: Optional[str] = None
+dataset_id: str | None = None


 class CheckDependenciesResult(BaseModel):

@@ -121,12 +121,12 @@ class RagPipelineDslService:
 *,
 account: Account,
 import_mode: str,
-yaml_content: Optional[str] = None,
-yaml_url: Optional[str] = None,
-pipeline_id: Optional[str] = None,
-dataset: Optional[Dataset] = None,
-dataset_name: Optional[str] = None,
-icon_info: Optional[IconInfo] = None,
+yaml_content: str | None = None,
+yaml_url: str | None = None,
+pipeline_id: str | None = None,
+dataset: Dataset | None = None,
+dataset_name: str | None = None,
+icon_info: IconInfo | None = None,
 ) -> RagPipelineImportInfo:
 """Import an app from YAML content or URL."""
 import_id = str(uuid.uuid4())

@@ -530,10 +530,10 @@ class RagPipelineDslService:
 def _create_or_update_pipeline(
 self,
 *,
-pipeline: Optional[Pipeline],
+pipeline: Pipeline | None,
 data: dict,
 account: Account,
-dependencies: Optional[list[PluginDependency]] = None,
+dependencies: list[PluginDependency] | None = None,
 ) -> Pipeline:
 """Create a new app or update an existing one."""
 if not account.current_tenant_id:
@@ -1,7 +1,6 @@
 import json
 from datetime import UTC, datetime
 from pathlib import Path
-from typing import Optional
 from uuid import uuid4

 import yaml

@@ -21,7 +20,7 @@ from services.plugin.plugin_service import PluginService

 class RagPipelineTransformService:
 def transform_dataset(self, dataset_id: str):
-dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
 if not dataset:
 raise ValueError("Dataset not found")
 if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":

@@ -90,7 +89,7 @@ class RagPipelineTransformService:
 "status": "success",
 }

-def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
+def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
 pipeline_yaml = {}
 if doc_form == "text_model":
 match datasource_type:

@@ -152,7 +151,7 @@ class RagPipelineTransformService:
 return node

 def _deal_knowledge_index(
-self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
+self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
 ):
 knowledge_configuration_dict = node.get("data", {})
 knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)

@@ -289,7 +288,7 @@ class RagPipelineTransformService:
 jina_node_id = "1752491761974"
 firecrawl_node_id = "1752565402678"

-documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all()
+documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()

 for document in documents:
 data_source_info_dict = document.data_source_info_dict

@@ -299,7 +298,7 @@ class RagPipelineTransformService:
 document.data_source_type = "local_file"
 file_id = data_source_info_dict.get("upload_file_id")
 if file_id:
-file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
 if file:
 data_source_info = json.dumps(
 {
@@ -1,6 +1,5 @@
 import logging
 import time
-from typing import Optional

 import click
 from celery import shared_task

@@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)


 @shared_task(queue="dataset")
-def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]):
+def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
 """
 Clean document when document deleted.
 :param document_ids: document ids
@@ -44,7 +44,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):

 if dataset_documents:
 dataset_documents_ids = [doc.id for doc in dataset_documents]
-db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
+db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
 {"indexing_status": "indexing"}, synchronize_session=False
 )
 db.session.commit()

@@ -76,12 +76,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
 # clean keywords
 index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
 index_processor.load(dataset, documents, with_keywords=False)
-db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
 {"indexing_status": "completed"}, synchronize_session=False
 )
 db.session.commit()
 except Exception as e:
-db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
 {"indexing_status": "error", "error": str(e)}, synchronize_session=False
 )
 db.session.commit()

@@ -100,7 +100,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
 if dataset_documents:
 # update document status
 dataset_documents_ids = [doc.id for doc in dataset_documents]
-db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
+db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
 {"indexing_status": "indexing"}, synchronize_session=False
 )
 db.session.commit()

@@ -148,12 +148,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
 documents.append(document)
 # save vector index
 index_processor.load(dataset, documents, with_keywords=False)
-db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
 {"indexing_status": "completed"}, synchronize_session=False
 )
 db.session.commit()
 except Exception as e:
-db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
 {"indexing_status": "error", "error": str(e)}, synchronize_session=False
 )
 db.session.commit()
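Note: the bulk updates above use the `Query.update()` pattern, where each call issues a single UPDATE statement and `synchronize_session=False` skips syncing already-loaded objects in the session. A minimal sketch of the same pattern with a throwaway `Doc` model (not Dify's `DatasetDocument`), assuming SQLAlchemy 1.4+:

from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Doc(Base):
    # stand-in for DatasetDocument, not the real model
    __tablename__ = "docs"
    id = Column(String, primary_key=True)
    indexing_status = Column(String, default="waiting")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(id="d1"), Doc(id="d2")])
    session.commit()
    ids = ["d1", "d2"]
    # one UPDATE ... WHERE id IN (...), no per-object session sync
    session.query(Doc).where(Doc.id.in_(ids)).update(
        {"indexing_status": "indexing"}, synchronize_session=False
    )
    session.commit()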
@@ -104,20 +104,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],

 with Session(db.engine, expire_on_commit=False) as session:
 # Load required entities
-account = session.query(Account).filter(Account.id == user_id).first()
+account = session.query(Account).where(Account.id == user_id).first()
 if not account:
 raise ValueError(f"Account {user_id} not found")

-tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
+tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
 if not tenant:
 raise ValueError(f"Tenant {tenant_id} not found")
 account.current_tenant = tenant

-pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
+pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
 if not pipeline:
 raise ValueError(f"Pipeline {pipeline_id} not found")

-workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
+workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
 if not workflow:
 raise ValueError(f"Workflow {pipeline.workflow_id} not found")

@@ -125,20 +125,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],

 with Session(db.engine) as session:
 # Load required entities
-account = session.query(Account).filter(Account.id == user_id).first()
+account = session.query(Account).where(Account.id == user_id).first()
 if not account:
 raise ValueError(f"Account {user_id} not found")

-tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
+tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
 if not tenant:
 raise ValueError(f"Tenant {tenant_id} not found")
 account.current_tenant = tenant

-pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
+pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
 if not pipeline:
 raise ValueError(f"Pipeline {pipeline_id} not found")

-workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
+workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
 if not workflow:
 raise ValueError(f"Workflow {pipeline.workflow_id} not found")
@@ -38,7 +38,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
 if not user:
 logger.info(click.style(f"User not found: {user_id}", fg="red"))
 return
-tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
+tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
 if not tenant:
 raise ValueError("Tenant not found")
 user.current_tenant = tenant