Merge branch 'feat/rag-2' into fix/dependency-version

Wu Tianwei 2025-09-18 11:06:26 +08:00 committed by GitHub
commit 18027b530a
66 changed files with 516 additions and 475 deletions

View File

@ -1440,12 +1440,12 @@ def transform_datasource_credentials():
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
for credential in notion_credentials:
tenant_id = credential.tenant_id
for notion_credential in notion_credentials:
tenant_id = notion_credential.tenant_id
if tenant_id not in notion_credentials_tenant_mapping:
notion_credentials_tenant_mapping[tenant_id] = []
notion_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in notion_credentials_tenant_mapping.items():
notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@ -1454,12 +1454,12 @@ def transform_datasource_credentials():
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
for credential in credentials:
for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
access_token = credential.access_token
access_token = notion_tenant_credential.access_token
# notion info
notion_info = credential.source_info
notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")
@ -1487,12 +1487,12 @@ def transform_datasource_credentials():
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for credential in firecrawl_credentials:
tenant_id = credential.tenant_id
for firecrawl_credential in firecrawl_credentials:
tenant_id = firecrawl_credential.tenant_id
if tenant_id not in firecrawl_credentials_tenant_mapping:
firecrawl_credentials_tenant_mapping[tenant_id] = []
firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@ -1502,10 +1502,10 @@ def transform_datasource_credentials():
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])
auth_count = 0
for credential in credentials:
for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(credential.credentials)
credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {
@ -1530,12 +1530,12 @@ def transform_datasource_credentials():
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
for credential in jina_credentials:
tenant_id = credential.tenant_id
for jina_credential in jina_credentials:
tenant_id = jina_credential.tenant_id
if tenant_id not in jina_credentials_tenant_mapping:
jina_credentials_tenant_mapping[tenant_id] = []
jina_credentials_tenant_mapping[tenant_id].append(credential)
for tenant_id, credentials in jina_credentials_tenant_mapping.items():
jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]
@ -1546,10 +1546,10 @@ def transform_datasource_credentials():
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])
auth_count = 0
for credential in credentials:
for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
credentials_json = json.loads(credential.credentials)
credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,

View File

@ -1,3 +1,3 @@
from .app_config import DifyConfig
dify_config = DifyConfig() # pyright: ignore[reportCallIssue]
dify_config = DifyConfig() # type: ignore

View File

@ -512,11 +512,11 @@ class WorkflowVariableTruncationConfig(BaseSettings):
description="Maximum size for variable to trigger final truncation.",
)
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
50000,
100000,
description="maximum length for string to trigger tuncation, measure in number of characters",
)
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
100,
1000,
description="maximum length for array to trigger truncation.",
)
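
The two limits above are ordinary pydantic-settings fields, so raising the defaults only changes behaviour when no environment variable overrides them. A minimal sketch of an equivalent settings class, assuming the project's usual pydantic-settings setup (the class name here is hypothetical):

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class TruncationSettings(BaseSettings):
    WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
        100000,
        description="Maximum string length (in characters) before truncation is applied.",
    )
    WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
        1000,
        description="Maximum array length before truncation is applied.",
    )


# environment variables with the same names still take precedence over these defaults
settings = TruncationSettings()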

View File

@ -104,7 +104,7 @@ class CustomizedPipelineTemplateApi(Resource):
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")

View File

@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
from typing import Optional
from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db
@ -10,7 +9,7 @@ from models.dataset import Pipeline
def get_rag_pipeline(
view: Optional[Callable] = None,
view: Callable | None = None,
):
def decorator(view_func):
@wraps(view_func)
@ -28,7 +27,7 @@ def get_rag_pipeline(
pipeline = (
db.session.query(Pipeline)
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.first()
)
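
Both query rewrites in this file are the same mechanical move: SQLAlchemy's Query.where() is a synonym for Query.filter() introduced for 2.0-style consistency, so behaviour is unchanged. A hedged sketch against a toy model (the table and session setup are illustrative, not the project's):

from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Pipeline(Base):
    __tablename__ = "pipelines"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Pipeline(id="p1", tenant_id="t1"))
    session.commit()

    # legacy 1.x spelling
    legacy = session.query(Pipeline).filter(Pipeline.id == "p1", Pipeline.tenant_id == "t1").first()
    # 2.0-style spelling used by the new code; same row, same behaviour
    modern = session.query(Pipeline).where(Pipeline.id == "p1", Pipeline.tenant_id == "t1").first()
    assert legacy is modern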

View File

@ -421,11 +421,10 @@ class PluginUploadFileRequestApi(Resource):
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
# generate signed url
url = get_signed_file_url_for_plugin(
payload.filename,
payload.mimetype,
tenant_model.id,
user_model.id,
user_model.session_id if isinstance(user_model, EndUser) else None,
filename=payload.filename,
mimetype=payload.mimetype,
tenant_id=tenant_model.id,
user_id=user_model.id,
)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()

View File

@ -32,11 +32,20 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_model = (
session.query(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = (
session.query(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
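
The reworked lookup above resolves the caller by EndUser.id first, falls back to EndUser.session_id, and only then creates a fresh record. A condensed sketch of that resolution order, with an in-memory list standing in for the ORM query:

from dataclasses import dataclass, field
from uuid import uuid4


@dataclass
class EndUser:
    # stand-in for the ORM model
    tenant_id: str
    id: str = field(default_factory=lambda: uuid4().hex)
    session_id: str | None = None


def get_user(users: list[EndUser], tenant_id: str, user_id: str | None) -> EndUser:
    # 1) match on the primary key
    by_id = next((u for u in users if u.id == user_id and u.tenant_id == tenant_id), None)
    if by_id:
        return by_id
    # 2) fall back to the session id
    by_session = next((u for u in users if u.session_id == user_id and u.tenant_id == tenant_id), None)
    if by_session:
        return by_session
    # 3) create a new end user for this tenant
    created = EndUser(tenant_id=tenant_id, session_id=user_id)
    users.append(created)
    return created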

View File

@ -124,6 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
@ -204,6 +207,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)
@ -308,6 +313,8 @@ class DocumentAddByFileApi(DatasetApiResource):
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
if not current_user:
raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
@ -396,8 +403,12 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("current_user is required")
if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
@ -577,7 +588,7 @@ class DocumentApi(DatasetApiResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@ -610,7 +621,7 @@ class DocumentApi(DatasetApiResource):
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,

View File

@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException):
error_code = "dataset_in_use"
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
code = 409
class PipelineRunError(BaseHTTPException):
error_code = "pipeline_run_error"
description = "An error occurred while running the pipeline."
code = 500

View File

@ -216,6 +216,9 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError
if not current_user:
raise ValueError("Invalid user account")
try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,

View File

@ -1,6 +1,6 @@
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal, Optional
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
@field_validator("description", mode="before")
@classmethod
@ -134,8 +134,8 @@ class RagPipelineVariableEntity(VariableEntity):
Rag Pipeline Variable Entity.
"""
tooltips: Optional[str] = None
placeholder: Optional[str] = None
tooltips: str | None = None
placeholder: str | None = None
belong_to_node_id: str
@ -298,7 +298,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
additional_features: Optional[AppAdditionalFeatures] = None
additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
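
Most of the typing churn across these files is the same mechanical substitution: Optional[X] becomes X | None (PEP 604), and pydantic treats the two annotations identically. A throwaway model pair showing the equivalence (names are illustrative, not the real VariableEntity):

from typing import Optional

from pydantic import BaseModel


class OldStyle(BaseModel):
    max_length: Optional[int] = None


class NewStyle(BaseModel):
    max_length: int | None = None


# both accept the same payloads and validate the same way
assert OldStyle(max_length=5).max_length == NewStyle(max_length=5).max_length
assert OldStyle().max_length is None and NewStyle().max_length is None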

View File

@ -45,9 +45,9 @@ class BaseAppGenerator:
mapping=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType]
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType]
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType]
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
)
@ -60,9 +60,9 @@ class BaseAppGenerator:
mappings=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType]
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType]
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType]
allowed_file_types=entity_dictionary[k].allowed_file_types or [],
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
)
for k, v in user_inputs.items()

View File

@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
:param blocking_response: blocking response
:return:
"""
return dict(blocking_response.to_dict())
return dict(blocking_response.model_dump())
@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.to_dict())
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
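
This converter switches from a hand-rolled to_dict() to pydantic's built-in model_dump(); for a plain model the two produce the same mapping. A hedged illustration with a toy response class, since the real stream-response models may layer extra serialisation on top:

from pydantic import BaseModel


class ToyBlockingResponse(BaseModel):
    task_id: str
    workflow_run_id: str

    def to_dict(self) -> dict:
        # hand-rolled equivalent of what model_dump() already provides
        return {"task_id": self.task_id, "workflow_run_id": self.workflow_run_id}


resp = ToyBlockingResponse(task_id="t-1", workflow_run_id="run-1")
assert resp.to_dict() == resp.model_dump()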

View File

@ -7,7 +7,7 @@ import threading
import time
import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, cast, overload
from typing import Any, Literal, Union, cast, overload
from flask import Flask, current_app
from pydantic import ValidationError
@ -69,7 +69,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_thread_pool_id: Optional[str],
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Generator[Mapping | str, None, None]: ...
@ -84,7 +84,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_thread_pool_id: Optional[str],
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Mapping[str, Any]: ...
@ -99,7 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_thread_pool_id: Optional[str],
workflow_thread_pool_id: str | None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
@ -113,7 +113,7 @@ class PipelineGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
is_retry: bool = False,
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
# Add null check for dataset
@ -314,7 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@ -331,7 +331,7 @@ class PipelineGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
# init queue manager
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
queue_manager = PipelineQueueManager(
@ -568,7 +568,7 @@ class PipelineGenerator(BaseAppGenerator):
queue_manager: AppQueueManager,
context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
Generate worker in a new thread.
@ -744,7 +744,7 @@ class PipelineGenerator(BaseAppGenerator):
Format datasource info list.
"""
if datasource_type == "online_drive":
all_files = []
all_files: list[Mapping[str, Any]] = []
datasource_node_data = None
datasource_nodes = workflow.graph_dict.get("nodes", [])
for datasource_node in datasource_nodes:
@ -801,11 +801,11 @@ class PipelineGenerator(BaseAppGenerator):
self,
datasource_runtime: OnlineDriveDatasourcePlugin,
prefix: str,
bucket: Optional[str],
bucket: str | None,
user_id: str,
all_files: list,
datasource_info: Mapping[str, Any],
next_page_parameters: Optional[dict] = None,
next_page_parameters: dict | None = None,
):
"""
Get files in a folder.

View File

@ -1,6 +1,6 @@
import logging
import time
from typing import Optional, cast
from typing import cast
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
@ -40,7 +40,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
workflow_thread_pool_id: Optional[str] = None,
workflow_thread_pool_id: str | None = None,
) -> None:
"""
:param application_generate_entity: application generate entity
@ -69,13 +69,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
@ -188,16 +188,14 @@ class PipelineRunner(WorkflowBasedAppRunner):
)
self._handle_event(workflow_entry, event)
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
)
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
.first()
)
@ -205,7 +203,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
) -> Graph:
"""
Init pipeline graph
@ -272,7 +270,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
if document_id and dataset_id:
document = (
db.session.query(Document)
.filter(Document.id == document_id, Document.dataset_id == dataset_id)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:

View File

@ -242,7 +242,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
node_id: str
inputs: dict
single_loop_run: Optional[SingleLoopRunEntity] = None
single_loop_run: SingleLoopRunEntity | None = None
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
@ -256,9 +256,9 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
datasource_info: Mapping[str, Any]
dataset_id: str
batch: str
document_id: Optional[str] = None
original_document_id: Optional[str] = None
start_node_id: Optional[str] = None
document_id: str | None = None
original_document_id: str | None = None
start_node_id: str | None = None
# Import TraceQueueManager at runtime to resolve forward references

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, ConfigDict, Field
@ -252,8 +252,8 @@ class NodeStartStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
created_at: int
extras: dict[str, object] = Field(default_factory=dict)
@ -310,12 +310,12 @@ class NodeFinishStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Optional[Mapping[str, Any]] = None
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Optional[Mapping[str, Any]] = None
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = True
status: str
error: str | None = None
@ -382,12 +382,12 @@ class NodeRetryStreamResponse(StreamResponse):
node_type: str
title: str
index: int
predecessor_node_id: Optional[str] = None
inputs: Optional[Mapping[str, Any]] = None
predecessor_node_id: str | None = None
inputs: Mapping[str, Any] | None = None
inputs_truncated: bool = False
process_data: Optional[Mapping[str, Any]] = None
process_data: Mapping[str, Any] | None = None
process_data_truncated: bool = False
outputs: Optional[Mapping[str, Any]] = None
outputs: Mapping[str, Any] | None = None
outputs_truncated: bool = False
status: str
error: str | None = None
@ -503,11 +503,11 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: str | None = None
@ -541,8 +541,8 @@ class LoopNodeStartStreamResponse(StreamResponse):
metadata: Mapping = {}
inputs: Mapping = {}
inputs_truncated: bool = False
parallel_id: Optional[str] = None
parallel_start_node_id: Optional[str] = None
parallel_id: str | None = None
parallel_start_node_id: str | None = None
event: StreamEvent = StreamEvent.LOOP_STARTED
workflow_run_id: str
@ -590,11 +590,11 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
node_id: str
node_type: str
title: str
outputs: Optional[Mapping] = None
outputs: Mapping | None = None
outputs_truncated: bool = False
created_at: int
extras: Optional[dict] = None
inputs: Optional[Mapping] = None
extras: dict | None = None
inputs: Mapping | None = None
inputs_truncated: bool = False
status: WorkflowNodeExecutionStatus
error: str | None = None

View File

@ -17,9 +17,9 @@ class DatasourceRuntime(BaseModel):
"""
tenant_id: str
datasource_id: Optional[str] = None
datasource_id: str | None = None
invoke_from: Optional["InvokeFrom"] = None
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
datasource_invoke_from: DatasourceInvokeFrom | None = None
credentials: dict[str, Any] = Field(default_factory=dict)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)

View File

@ -6,7 +6,7 @@ import os
import time
from datetime import datetime
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from typing import Union
from uuid import uuid4
import httpx
@ -62,10 +62,10 @@ class DatasourceFileManager:
*,
user_id: str,
tenant_id: str,
conversation_id: Optional[str],
conversation_id: str | None,
file_binary: bytes,
mimetype: str,
filename: Optional[str] = None,
filename: str | None = None,
) -> UploadFile:
extension = guess_extension(mimetype) or ".bin"
unique_name = uuid4().hex
@ -106,7 +106,7 @@ class DatasourceFileManager:
user_id: str,
tenant_id: str,
file_url: str,
conversation_id: Optional[str] = None,
conversation_id: str | None = None,
) -> ToolFile:
# try to download image
try:
@ -152,13 +152,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == id,
)
.first()
)
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
if not upload_file:
return None
@ -176,13 +170,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
message_file: MessageFile | None = (
db.session.query(MessageFile)
.filter(
MessageFile.id == id,
)
.first()
)
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
# Check if message_file is not None
if message_file is not None:
@ -196,13 +184,7 @@ class DatasourceFileManager:
else:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
.first()
)
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
if not tool_file:
return None
@ -220,13 +202,7 @@ class DatasourceFileManager:
:return: the binary of the file, mime type
"""
upload_file: UploadFile | None = (
db.session.query(UploadFile)
.filter(
UploadFile.id == upload_file_id,
)
.first()
)
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if not upload_file:
return None, None

View File

@ -46,7 +46,7 @@ class DatasourceManager:
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
controller: DatasourcePluginProviderController | None = None
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(
@ -79,7 +79,11 @@ class DatasourceManager:
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")
datasource_plugin_providers[provider_id] = controller
if controller:
datasource_plugin_providers[provider_id] = controller
if controller is None:
raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")
return controller
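
The fix above pre-declares controller as None ahead of the match statement and raises when no branch produced one, instead of assigning inside every branch and hoping the variable is bound. A reduced sketch of that shape; the enum values and controller class are placeholders:

from enum import StrEnum


class DatasourceProviderType(StrEnum):
    ONLINE_DOCUMENT = "online_document"
    WEBSITE_CRAWL = "website_crawl"


class ProviderController:
    # placeholder for the plugin provider controllers
    def __init__(self, kind: str):
        self.kind = kind


def get_controller(datasource_type: DatasourceProviderType) -> ProviderController:
    controller: ProviderController | None = None
    match datasource_type:
        case DatasourceProviderType.ONLINE_DOCUMENT:
            controller = ProviderController("online_document")
        case DatasourceProviderType.WEBSITE_CRAWL:
            controller = ProviderController("website_crawl")
        case _:
            raise ValueError(f"Unsupported datasource type: {datasource_type}")
    if controller is None:
        # defensive check mirroring the hunk above, for type checkers that do not narrow the match
        raise LookupError("Datasource provider not found.")
    return controller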

View File

@ -12,9 +12,9 @@ class DatasourceApiEntity(BaseModel):
name: str # identifier
label: I18nObject # label
description: I18nObject
parameters: Optional[list[DatasourceParameter]] = None
parameters: list[DatasourceParameter] | None = None
labels: list[str] = Field(default_factory=list)
output_schema: Optional[dict] = None
output_schema: dict | None = None
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
@ -28,12 +28,12 @@ class DatasourceProviderApiEntity(BaseModel):
icon: str | dict
label: I18nObject # label
type: str
masked_credentials: Optional[dict] = None
original_credentials: Optional[dict] = None
masked_credentials: dict | None = None
original_credentials: dict | None = None
is_team_authorization: bool = False
allow_delete: bool = True
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource")
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
@ -62,7 +62,7 @@ class DatasourceProviderApiEntity(BaseModel):
"description": self.description.to_dict(),
"icon": self.icon,
"label": self.label.to_dict(),
"type": self.type.value,
"type": self.type,
"team_credentials": self.masked_credentials,
"is_team_authorization": self.is_team_authorization,
"allow_delete": self.allow_delete,

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, Field
@ -9,9 +7,9 @@ class I18nObject(BaseModel):
"""
en_US: str
zh_Hans: Optional[str] = Field(default=None)
pt_BR: Optional[str] = Field(default=None)
ja_JP: Optional[str] = Field(default=None)
zh_Hans: str | None = Field(default=None)
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)

View File

@ -1,6 +1,6 @@
import enum
from enum import Enum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field, ValidationInfo, field_validator
from yarl import URL
@ -80,7 +80,7 @@ class DatasourceParameter(PluginParameter):
name: str,
typ: DatasourceParameterType,
required: bool,
options: Optional[list[str]] = None,
options: list[str] | None = None,
) -> "DatasourceParameter":
"""
get a simple datasource parameter
@ -120,14 +120,14 @@ class DatasourceIdentity(BaseModel):
name: str = Field(..., description="The name of the datasource")
label: I18nObject = Field(..., description="The label of the datasource")
provider: str = Field(..., description="The provider of the datasource")
icon: Optional[str] = None
icon: str | None = None
class DatasourceEntity(BaseModel):
identity: DatasourceIdentity
parameters: list[DatasourceParameter] = Field(default_factory=list)
description: I18nObject = Field(..., description="The label of the datasource")
output_schema: Optional[dict] = None
output_schema: dict | None = None
@field_validator("parameters", mode="before")
@classmethod
@ -141,7 +141,7 @@ class DatasourceProviderIdentity(BaseModel):
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
tags: list[ToolLabelEnum] | None = Field(
default=[],
description="The tags of the tool",
)
@ -169,7 +169,7 @@ class DatasourceProviderEntity(BaseModel):
identity: DatasourceProviderIdentity
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
oauth_schema: Optional[OAuthSchema] = None
oauth_schema: OAuthSchema | None = None
provider_type: DatasourceProviderType
@ -183,8 +183,8 @@ class DatasourceInvokeMeta(BaseModel):
"""
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: Optional[str] = None
tool_config: Optional[dict] = None
error: str | None = None
tool_config: dict | None = None
@classmethod
def empty(cls) -> "DatasourceInvokeMeta":
@ -233,10 +233,10 @@ class OnlineDocumentPage(BaseModel):
page_id: str = Field(..., description="The page id")
page_name: str = Field(..., description="The page title")
page_icon: Optional[dict] = Field(None, description="The page icon")
page_icon: dict | None = Field(None, description="The page icon")
type: str = Field(..., description="The type of the page")
last_edited_time: str = Field(..., description="The last edited time")
parent_id: Optional[str] = Field(None, description="The parent page id")
parent_id: str | None = Field(None, description="The parent page id")
class OnlineDocumentInfo(BaseModel):
@ -244,9 +244,9 @@ class OnlineDocumentInfo(BaseModel):
Online document info
"""
workspace_id: Optional[str] = Field(None, description="The workspace id")
workspace_name: Optional[str] = Field(None, description="The workspace name")
workspace_icon: Optional[str] = Field(None, description="The workspace icon")
workspace_id: str | None = Field(None, description="The workspace id")
workspace_name: str | None = Field(None, description="The workspace name")
workspace_icon: str | None = Field(None, description="The workspace icon")
total: int = Field(..., description="The total number of documents")
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
@ -307,10 +307,10 @@ class WebSiteInfo(BaseModel):
Website info
"""
status: Optional[str] = Field(..., description="crawl job status")
web_info_list: Optional[list[WebSiteInfoDetail]] = []
total: Optional[int] = Field(default=0, description="The total number of websites")
completed: Optional[int] = Field(default=0, description="The number of completed websites")
status: str | None = Field(..., description="crawl job status")
web_info_list: list[WebSiteInfoDetail] | None = []
total: int | None = Field(default=0, description="The total number of websites")
completed: int | None = Field(default=0, description="The number of completed websites")
class WebsiteCrawlMessage(BaseModel):
@ -346,10 +346,10 @@ class OnlineDriveFileBucket(BaseModel):
Online drive file bucket
"""
bucket: Optional[str] = Field(None, description="The file bucket")
bucket: str | None = Field(None, description="The file bucket")
files: list[OnlineDriveFile] = Field(..., description="The file list")
is_truncated: bool = Field(False, description="Whether the result is truncated")
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesRequest(BaseModel):
@ -357,10 +357,10 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
Get online drive file list request
"""
bucket: Optional[str] = Field(None, description="The file bucket")
bucket: str | None = Field(None, description="The file bucket")
prefix: str = Field(..., description="The parent folder ID")
max_keys: int = Field(20, description="Page size for pagination")
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
class OnlineDriveBrowseFilesResponse(BaseModel):
@ -377,4 +377,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
"""
id: str = Field(..., description="The id of the file")
bucket: Optional[str] = Field(None, description="The name of the bucket")
bucket: str | None = Field(None, description="The name of the bucket")

View File

@ -1,11 +1,11 @@
import logging
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional
from core.datasource.entities.datasource_entities import DatasourceMessage
from core.file import File, FileTransferMethod, FileType
from core.tools.tool_file_manager import ToolFileManager
from models.tools import ToolFile
logger = logging.getLogger(__name__)
@ -17,7 +17,7 @@ class DatasourceFileMessageTransformer:
messages: Generator[DatasourceMessage, None, None],
user_id: str,
tenant_id: str,
conversation_id: Optional[str] = None,
conversation_id: str | None = None,
) -> Generator[DatasourceMessage, None, None]:
"""
Transform datasource message and handle file download
@ -32,20 +32,20 @@ class DatasourceFileMessageTransformer:
try:
assert isinstance(message.message, DatasourceMessage.TextMessage)
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_url(
tool_file: ToolFile | None = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
conversation_id=conversation_id,
)
if tool_file:
url = f"/files/datasources/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=message.meta.copy() if message.meta is not None else {},
)
except Exception as e:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.TEXT,
@ -72,7 +72,7 @@ class DatasourceFileMessageTransformer:
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_raw(
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
@ -80,25 +80,27 @@ class DatasourceFileMessageTransformer:
mimetype=mimetype,
filename=filename,
)
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
# check if file is image
if "image" in mimetype:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.BINARY_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
if blob_tool_file:
url = cls.get_datasource_file_url(
datasource_file_id=blob_tool_file.id, extension=guess_extension(blob_tool_file.mimetype)
)
# check if file is image
if "image" in mimetype:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.IMAGE_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
else:
yield DatasourceMessage(
type=DatasourceMessage.MessageType.BINARY_LINK,
message=DatasourceMessage.TextMessage(text=url),
meta=meta.copy() if meta is not None else {},
)
elif message.type == DatasourceMessage.MessageType.FILE:
meta = message.meta or {}
file = meta.get("file", None)
file: File | None = meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
@ -121,5 +123,5 @@ class DatasourceFileMessageTransformer:
yield message
@classmethod
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str:
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"

View File

@ -3,7 +3,6 @@ import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
from typing import Optional
from flask import request
from requests import get
@ -169,9 +168,9 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
parameter = parameter or {}
typ: Optional[str] = None
typ: str | None = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel
@ -16,7 +14,7 @@ class QAPreviewDetail(BaseModel):
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None
qa_preview: list[QAPreviewDetail] | None = None
class PipelineDataset(BaseModel):
@ -30,10 +28,10 @@ class PipelineDocument(BaseModel):
id: str
position: int
data_source_type: str
data_source_info: Optional[dict] = None
data_source_info: dict | None = None
name: str
indexing_status: str
error: Optional[str] = None
error: str | None = None
enabled: bool

View File

@ -25,9 +25,7 @@ def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
return f"{url}?{query_string}"
def get_signed_file_url_for_plugin(
filename: str, mimetype: str, tenant_id: str, user_id: str, session_id: str | None
) -> str:
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
# Plugin access should use internal URL for Docker network communication
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
url = f"{base_url}/files/upload/for-plugin"
@ -37,9 +35,7 @@ def get_signed_file_url_for_plugin(
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
url_user_id = session_id or user_id
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={url_user_id}&tenant_id={tenant_id}"
return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
def verify_plugin_file_signature(
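
The simplified signature drops the session-id substitution, so the signed query string always carries the caller's user_id. A self-contained sketch of the same HMAC-over-pipe-delimited-fields scheme; the secret, host, and route here are placeholders, not the project's configuration:

import base64
import hashlib
import hmac
import os
import time


def signed_plugin_upload_url(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
    base_url = "http://api:5001"          # stand-in for INTERNAL_FILES_URL / FILES_URL
    secret = b"example-shared-secret"     # stand-in for the configured signing key
    url = f"{base_url}/files/upload/for-plugin"
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    # sign the upload parameters as a pipe-delimited message
    msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
    sign = hmac.new(secret, msg.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()
    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"


print(signed_plugin_upload_url("report.pdf", "application/pdf", "tenant-1", "user-1"))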

View File

@ -119,6 +119,7 @@ class File(BaseModel):
assert self.related_id is not None
assert self.extension is not None
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
return None
def to_plugin_parameter(self) -> dict[str, Any]:
return {

View File

@ -1,7 +1,7 @@
import datetime
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Any, Optional
from typing import Any
from packaging.version import InvalidVersion, Version
from pydantic import BaseModel, Field, field_validator, model_validator
@ -67,10 +67,10 @@ class PluginCategory(StrEnum):
class PluginDeclaration(BaseModel):
class Plugins(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list[str])
datasources: Optional[list[str]] = Field(default_factory=list[str])
tools: list[str] | None = Field(default_factory=list[str])
models: list[str] | None = Field(default_factory=list[str])
endpoints: list[str] | None = Field(default_factory=list[str])
datasources: list[str] | None = Field(default_factory=list[str])
class Meta(BaseModel):
minimum_dify_version: str | None = Field(default=None)
@ -101,11 +101,11 @@ class PluginDeclaration(BaseModel):
tags: list[str] = Field(default_factory=list)
repo: str | None = Field(default=None)
verified: bool = Field(default=False)
tool: Optional[ToolProviderEntity] = None
model: Optional[ProviderEntity] = None
endpoint: Optional[EndpointProviderDeclaration] = None
agent_strategy: Optional[AgentStrategyProviderEntity] = None
datasource: Optional[DatasourceProviderEntity] = None
tool: ToolProviderEntity | None = None
model: ProviderEntity | None = None
endpoint: EndpointProviderDeclaration | None = None
agent_strategy: AgentStrategyProviderEntity | None = None
datasource: DatasourceProviderEntity | None = None
meta: Meta
@field_validator("version")

View File

@ -1,6 +1,6 @@
from collections.abc import Mapping
from enum import Enum
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -27,12 +27,12 @@ class DatasourceErrorEvent(BaseDatasourceEvent):
class DatasourceCompletedEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.COMPLETED.value
data: Mapping[str, Any] | list = Field(..., description="result")
total: Optional[int] = Field(default=0, description="total")
completed: Optional[int] = Field(default=0, description="completed")
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
total: int | None = Field(default=0, description="total")
completed: int | None = Field(default=0, description="completed")
time_consuming: float | None = Field(default=0.0, description="time consuming")
class DatasourceProcessingEvent(BaseDatasourceEvent):
event: str = DatasourceStreamEvent.PROCESSING.value
total: Optional[int] = Field(..., description="total")
completed: Optional[int] = Field(..., description="completed")
total: int | None = Field(..., description="total")
completed: int | None = Field(..., description="completed")

View File

@ -1,5 +1,3 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict
from models.dataset import Document
@ -11,7 +9,7 @@ class NotionInfo(BaseModel):
Notion import info.
"""
credential_id: Optional[str] = None
credential_id: str | None = None
notion_workspace_id: str
notion_obj_id: str
notion_page_type: str

View File

@ -1,7 +1,7 @@
import json
import logging
import operator
from typing import Any, Optional, cast
from typing import Any, cast
import requests
@ -35,9 +35,9 @@ class NotionExtractor(BaseExtractor):
notion_obj_id: str,
notion_page_type: str,
tenant_id: str,
document_model: Optional[DocumentModel] = None,
notion_access_token: Optional[str] = None,
credential_id: Optional[str] = None,
document_model: DocumentModel | None = None,
notion_access_token: str | None = None,
credential_id: str | None = None,
):
self._notion_access_token = None
self._document_model = document_model
@ -369,7 +369,7 @@ class NotionExtractor(BaseExtractor):
return cast(str, data["last_edited_time"])
@classmethod
def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
def _get_access_token(cls, tenant_id: str, credential_id: str | None) -> str:
# get credential from tenant_id and credential_id
if not credential_id:
raise Exception(f"No credential id found for tenant {tenant_id}")

View File

@ -7,7 +7,7 @@ import json
import logging
from collections.abc import Callable, Mapping, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional, TypeVar, Union
from typing import Any, TypeVar, Union
import psycopg2.errors
from sqlalchemy import UnaryExpression, asc, desc, select
@ -291,7 +291,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
return None
value_json = _deterministic_json_dump(json_encodable_value)
assert value_json is not None, "value_json should be None here."
assert value_json is not None, "value_json should be not None here."
suffix = type_.value
upload_file = self._file_service.upload_file(
@ -530,7 +530,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
def get_by_workflow_run(
self,
workflow_run_id: str,
order_config: Optional[OrderConfig] = None,
order_config: OrderConfig | None = None,
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
) -> Sequence[WorkflowNodeExecution]:
"""

View File

@ -85,7 +85,7 @@ class SchemaRegistry:
except (OSError, json.JSONDecodeError) as e:
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
def get_schema(self, uri: str) -> Optional[Any]:
def get_schema(self, uri: str) -> Any | None:
"""Retrieves a schema by URI with version support"""
version, schema_name = self._parse_uri(uri)
if not version or not schema_name:

View File

@ -3,7 +3,7 @@ import re
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Union
from core.schemas.registry import SchemaRegistry
@ -53,8 +53,8 @@ class QueueItem:
"""Represents an item in the BFS queue"""
current: Any
parent: Optional[Any]
key: Optional[Union[str, int]]
parent: Any | None
key: Union[str, int] | None
depth: int
ref_path: set[str]
@ -65,7 +65,7 @@ class SchemaResolver:
_cache: dict[str, SchemaDict] = {}
_cache_lock = threading.Lock()
def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10):
"""
Initialize the schema resolver
@ -202,7 +202,7 @@ class SchemaResolver:
)
)
def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None:
"""Get resolved schema from cache or registry"""
# Check cache first
with self._cache_lock:
@ -223,7 +223,7 @@ class SchemaResolver:
def resolve_dify_schema_refs(
schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30
) -> SchemaType:
"""
Resolve $ref references in Dify schema to actual schema content

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from core.schemas.registry import SchemaRegistry
@ -7,7 +7,7 @@ from core.schemas.registry import SchemaRegistry
class SchemaManager:
"""Schema manager provides high-level schema operations"""
def __init__(self, registry: Optional[SchemaRegistry] = None):
def __init__(self, registry: SchemaRegistry | None = None):
self.registry = registry or SchemaRegistry.default_registry()
def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:
@ -22,7 +22,7 @@ class SchemaManager:
"""
return self.registry.get_all_schemas_for_version(version)
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Mapping[str, Any] | None:
"""
Get a specific schema by name

View File

@ -67,7 +67,7 @@ class VariablePool(BaseModel):
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map = defaultdict(dict)
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
for rag_var in self.rag_pipeline_variables:
node_id = rag_var.variable.belong_to_node_id
key = rag_var.variable.variable

View File

@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc.
from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field, PrivateAttr
@ -53,9 +53,9 @@ class WorkflowNodeExecution(BaseModel):
# Execution data
# The `inputs` and `outputs` fields hold the full content
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status

View File

@ -13,4 +13,4 @@ class WorkflowNodeRunFailedError(Exception):
@property
def error(self) -> str:
return self._error
return self._error

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -19,7 +19,7 @@ from core.file.enums import FileTransferMethod, FileType
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.variables.segments import ArrayAnySegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -50,7 +50,7 @@ class DatasourceNode(Node):
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = DatasourceNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -59,7 +59,7 @@ class DatasourceNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -87,29 +87,18 @@ class DatasourceNode(Node):
raise DatasourceNodeError("Invalid datasource info format")
datasource_info: dict[str, Any] = datasource_info_value
# get datasource runtime
try:
from core.datasource.datasource_manager import DatasourceManager
from core.datasource.datasource_manager import DatasourceManager
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
if datasource_type is None:
raise DatasourceNodeError("Datasource type is not set")
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
except DatasourceNodeError as e:
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
error=f"Failed to get datasource runtime: {str(e)}",
error_type=type(e).__name__,
)
)
datasource_runtime = DatasourceManager.get_datasource_runtime(
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
datasource_name=node_data.datasource_name or "",
tenant_id=self.tenant_id,
datasource_type=DatasourceProviderType.value_of(datasource_type),
)
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
parameters_for_log = datasource_info
@ -179,7 +168,7 @@ class DatasourceNode(Node):
related_id = datasource_info.get("related_id")
if not related_id:
raise DatasourceNodeError("File is not exist")
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
if not upload_file:
raise ValueError("Invalid upload file Info")
@ -282,27 +271,6 @@ class DatasourceNode(Node):
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _append_variables_recursively(
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
variable_pool.add([node_id] + [".".join(variable_key_list)], variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, dict):
for key, value in variable_value.items():
# construct new key list
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
@ -423,13 +391,6 @@ class DatasourceNode(Node):
)
elif message.type == DatasourceMessage.MessageType.JSON:
assert isinstance(message.message, DatasourceMessage.JsonMessage)
if self.node_type == NodeType.AGENT:
msg_metadata = message.message.json_object.pop("execution_metadata", {})
agent_execution_metadata = {
key: value
for key, value in msg_metadata.items()
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
}
json.append(message.message.json_object)
elif message.type == DatasourceMessage.MessageType.LINK:
assert isinstance(message.message, DatasourceMessage.TextMessage)

View File

@ -1,4 +1,4 @@
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
@ -10,7 +10,7 @@ class DatasourceEntity(BaseModel):
plugin_id: str
provider_name: str # redundancy
provider_type: str
datasource_name: Optional[str] = "local_file"
datasource_name: str | None = "local_file"
datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy
@ -19,7 +19,7 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Optional[Literal["mixed", "variable", "constant"]] = None
type: Literal["mixed", "variable", "constant"] | None = None
@field_validator("type", mode="before")
@classmethod
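The Optional[...] to ... | None rewrites in this file and the ones below are purely syntactic: PEP 604 unions require Python 3.10+, and pydantic resolves str | None exactly like Optional[str]. A minimal standalone sketch (the model, field names, and validator below are placeholders, not the project's classes):
from pydantic import BaseModel, field_validator


class ExampleInput(BaseModel):
    value: str | None = None  # equivalent to Optional[str] = None
    type: str | None = None

    @field_validator("type", mode="before")
    @classmethod
    def normalize_type(cls, v: str | None) -> str | None:
        # hypothetical normalization, shown only to illustrate the validator shape
        return v.lower() if isinstance(v, str) else v


print(ExampleInput(value="abc", type="Mixed").type)  # -> "mixed"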

View File

@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel
@ -65,12 +65,12 @@ class RetrievalSetting(BaseModel):
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
@ -107,10 +107,10 @@ class OnlineDocumentInfo(BaseModel):
"""
provider: str
workspace_id: Optional[str] = None
workspace_id: str | None = None
page_id: str
page_type: str
icon: Optional[OnlineDocumentIcon] = None
icon: OnlineDocumentIcon | None = None
class WebsiteInfo(BaseModel):

View File

@ -2,7 +2,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from sqlalchemy import func, select
@ -43,7 +43,7 @@ class KnowledgeIndexNode(Node):
def init_node_data(self, data: Mapping[str, Any]) -> None:
self._node_data = KnowledgeIndexNodeData.model_validate(data)
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
@ -52,7 +52,7 @@ class KnowledgeIndexNode(Node):
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self) -> Optional[str]:
def _get_description(self) -> str | None:
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
@ -160,7 +160,7 @@ class KnowledgeIndexNode(Node):
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
document.word_count = (
db.session.query(func.sum(DocumentSegment.word_count))
.filter(
.where(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
)
@ -168,7 +168,7 @@ class KnowledgeIndexNode(Node):
)
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).filter(
db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document.id,
DocumentSegment.dataset_id == dataset.id,
).update(

View File

@ -92,7 +92,9 @@ class LoopNode(Node):
if self._node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
if loop_variable.value_type not in value_processor:
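The isinstance guard added above narrows var.value before it reaches the variable-pool lookup, which expects a selector (a list of path segments); constant string values now resolve to None instead of being passed as an invalid selector. A minimal sketch of that narrowing pattern under assumed names (the lookup helper and pool below are hypothetical, not Dify's VariablePool API):
from collections.abc import Sequence
from typing import Any


def lookup(pool: dict[tuple[str, ...], Any], selector: Any) -> Any | None:
    # only a non-string sequence of path segments counts as a valid selector
    if isinstance(selector, Sequence) and not isinstance(selector, str):
        return pool.get(tuple(selector))
    return None


pool = {("node_1", "output"): 42}
print(lookup(pool, ["node_1", "output"]))  # -> 42
print(lookup(pool, "a constant string"))   # -> None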

View File

@ -326,7 +326,7 @@ def _build_from_datasource_file(
) -> File:
datasource_file = (
db.session.query(UploadFile)
.filter(
.where(
UploadFile.id == mapping.get("datasource_file_id"),
UploadFile.tenant_id == tenant_id,
)
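The .filter() to .where() changes in this hunk and throughout the files below are behavior-preserving: since SQLAlchemy 1.4, Query.where() is a documented synonym for Query.filter(), aligning the legacy Query API with the 2.0-style select() API. A minimal standalone sketch (SQLAlchemy 2.x assumed; the Doc model is a placeholder, not a project model):
from sqlalchemy import Integer, String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    indexing_status: Mapped[str] = mapped_column(String(32))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Doc(indexing_status="completed"))
    session.commit()
    legacy = session.query(Doc).where(Doc.indexing_status == "completed").first()   # same as .filter(...)
    modern = session.scalar(select(Doc).where(Doc.indexing_status == "completed"))  # 2.0-style equivalent
    assert legacy is not None and modern is not None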

View File

@ -10,7 +10,7 @@ import re
import time
from datetime import datetime
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, cast
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func, select
@ -76,13 +76,13 @@ class Dataset(Base):
@property
def total_documents(self):
return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property
def total_available_documents(self):
return (
db.session.query(func.count(Document.id))
.filter(
.where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
@ -173,10 +173,10 @@ class Dataset(Base):
)
@property
def doc_form(self) -> Optional[str]:
def doc_form(self) -> str | None:
if self.chunk_structure:
return self.chunk_structure
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document:
return document.doc_form
return None
@ -234,7 +234,7 @@ class Dataset(Base):
@property
def is_published(self):
if self.pipeline_id:
pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
if pipeline:
return pipeline.is_published
return False
@ -1244,7 +1244,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
@property
def created_user_name(self):
account = db.session.query(Account).filter(Account.id == self.created_by).first()
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
@ -1274,7 +1274,7 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
@property
def created_user_name(self):
account = db.session.query(Account).filter(Account.id == self.created_by).first()
account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
return ""
@ -1297,7 +1297,7 @@ class Pipeline(Base): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def retrieve_dataset(self, session: Session):
return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
class DocumentPipelineExecutionLog(Base):

View File

@ -1546,7 +1546,7 @@ class WorkflowDraftVariableFile(Base):
comment="Size of the original variable content in bytes",
)
length: Mapped[Optional[int]] = mapped_column(
length: Mapped[int | None] = mapped_column(
sa.Integer,
nullable=True,
comment=(

View File

@ -7,7 +7,7 @@ import time
import uuid
from collections import Counter
from collections.abc import Sequence
from typing import Any, Literal, Optional
from typing import Any, Literal
import sqlalchemy as sa
from sqlalchemy import exists, func, select
@ -315,8 +315,8 @@ class DatasetService:
return dataset
@staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
def get_dataset(dataset_id) -> Dataset | None:
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset
@staticmethod
@ -419,7 +419,7 @@ class DatasetService:
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
dataset = (
db.session.query(Dataset)
.filter(
.where(
Dataset.id != dataset_id,
Dataset.name == name,
Dataset.tenant_id == tenant_id,

View File

@ -77,7 +77,7 @@ class DatasourceProviderService:
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
encrypted_credentials = raw_credentials.copy()
encrypted_credentials = dict(raw_credentials)
for key, value in encrypted_credentials.items():
if key in provider_credential_secret_variables:
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
@ -690,7 +690,7 @@ class DatasourceProviderService:
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
@ -862,7 +862,7 @@ class DatasourceProviderService:
# Get all provider configurations of the current workspace
datasource_providers: list[DatasourceProvider] = (
db.session.query(DatasourceProvider)
.filter(
.where(
DatasourceProvider.tenant_id == tenant_id,
DatasourceProvider.provider == provider,
DatasourceProvider.plugin_id == plugin_id,
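The raw_credentials.copy() to dict(raw_credentials) change above matters because the credentials arrive as a Mapping, which does not guarantee a .copy() method, and the result must be a mutable dict since secret fields are overwritten in place. A minimal sketch under assumed names (encrypt_token here is a stand-in, not the project's encrypter):
from collections.abc import Mapping


def encrypt_token(tenant_id: str, value: str) -> str:
    # hypothetical stand-in for the real encrypter
    return f"enc({tenant_id}:{value})"


def encrypt_secrets(tenant_id: str, raw: Mapping[str, str], secret_keys: set[str]) -> dict[str, str]:
    encrypted = dict(raw)  # mutable copy that accepts any Mapping, not just dict
    for key, value in encrypted.items():
        if key in secret_keys:
            encrypted[key] = encrypt_token(tenant_id, value)
    return encrypted


print(encrypt_secrets("t1", {"api_key": "sk-123", "base_url": "https://example.com"}, {"api_key"}))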

View File

@ -1,13 +1,13 @@
from typing import Literal, Optional
from typing import Literal
from pydantic import BaseModel, field_validator
class IconInfo(BaseModel):
icon: str
icon_background: Optional[str] = None
icon_type: Optional[str] = None
icon_url: Optional[str] = None
icon_background: str | None = None
icon_type: str | None = None
icon_url: str | None = None
class PipelineTemplateInfoEntity(BaseModel):
@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel):
description: str
icon_info: IconInfo
permission: str
partial_member_list: Optional[list[str]] = None
yaml_content: Optional[str] = None
partial_member_list: list[str] | None = None
yaml_content: str | None = None
class RerankingModelConfig(BaseModel):
@ -30,8 +30,8 @@ class RerankingModelConfig(BaseModel):
Reranking Model Config.
"""
reranking_provider_name: Optional[str] = ""
reranking_model_name: Optional[str] = ""
reranking_provider_name: str | None = ""
reranking_model_name: str | None = ""
class VectorSetting(BaseModel):
@ -57,8 +57,8 @@ class WeightedScoreConfig(BaseModel):
Weighted score Config.
"""
vector_setting: Optional[VectorSetting]
keyword_setting: Optional[KeywordSetting]
vector_setting: VectorSetting | None
keyword_setting: KeywordSetting | None
class EmbeddingSetting(BaseModel):
@ -85,12 +85,12 @@ class RetrievalSetting(BaseModel):
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: Optional[str] = "reranking_model"
reranking_enable: Optional[bool] = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_mode: str | None = "reranking_model"
reranking_enable: bool | None = True
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None
class IndexMethod(BaseModel):
@ -112,7 +112,7 @@ class KnowledgeConfiguration(BaseModel):
indexing_technique: Literal["high_quality", "economy"]
embedding_model_provider: str = ""
embedding_model: str = ""
keyword_number: Optional[int] = 10
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
@field_validator("embedding_model_provider", mode="before")

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel
@ -7,9 +7,9 @@ from pydantic import BaseModel
class DatasourceNodeRunApiEntity(BaseModel):
pipeline_id: str
node_id: str
inputs: Mapping[str, Any]
inputs: dict[str, Any]
datasource_type: str
credential_id: Optional[str] = None
credential_id: str | None = None
is_published: bool

View File

@ -108,7 +108,7 @@ class PipelineGenerateService:
Update document status to waiting
:param document_id: document id
"""
document = db.session.query(Document).filter(Document.id == document_id).first()
document = db.session.query(Document).where(Document.id == document_id).first()
if document:
document.indexing_status = "waiting"
db.session.add(document)

View File

@ -1,7 +1,6 @@
import json
from os import path
from pathlib import Path
from typing import Optional
from flask import current_app
@ -14,7 +13,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
Retrieve pipeline templates from built-in data; the location is constants/pipeline_templates.json
"""
builtin_data: Optional[dict] = None
builtin_data: dict | None = None
def get_type(self) -> str:
return PipelineTemplateType.BUILTIN
@ -54,7 +53,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return builtin_data.get("pipeline_templates", {}).get(language, {})
@classmethod
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from builtin.
:param template_id: Template ID

View File

@ -1,5 +1,3 @@
from typing import Optional
import yaml
from flask_login import current_user
@ -37,7 +35,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
pipeline_customized_templates = (
db.session.query(PipelineCustomizedTemplate)
.filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
.all()
)
@ -56,14 +54,14 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID
:return:
"""
pipeline_template = (
db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not pipeline_template:
return None

View File

@ -1,5 +1,3 @@
from typing import Optional
import yaml
from extensions.ext_database import db
@ -33,7 +31,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
)
recommended_pipelines_results = []
@ -53,7 +51,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return {"pipeline_templates": recommended_pipelines_results}
@classmethod
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from db.
:param template_id: Template ID
@ -61,7 +59,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
"""
# is in public recommended list
pipeline_template = (
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == template_id).first()
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
)
if not pipeline_template:

View File

@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from typing import Optional
class PipelineTemplateRetrievalBase(ABC):
@ -10,7 +9,7 @@ class PipelineTemplateRetrievalBase(ABC):
raise NotImplementedError
@abstractmethod
def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]:
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
raise NotImplementedError
@abstractmethod

View File

@ -1,5 +1,4 @@
import logging
from typing import Optional
import requests
@ -36,7 +35,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
return PipelineTemplateType.REMOTE
@classmethod
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> Optional[dict]:
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
"""
Fetch pipeline template detail from dify official.
:param template_id: Template ID

View File

@ -5,7 +5,7 @@ import threading
import time
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from typing import Any, Union, cast
from uuid import uuid4
from flask_login import current_user
@ -112,7 +112,7 @@ class RagPipelineService:
return result
@classmethod
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
"""
Get pipeline template detail.
:param template_id: template id
@ -121,12 +121,12 @@ class RagPipelineService:
if type == "built-in":
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
return built_in_result
else:
mode = "customized"
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
return customized_result
@classmethod
@ -138,7 +138,7 @@ class RagPipelineService:
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
.filter(
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
@ -151,7 +151,7 @@ class RagPipelineService:
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
.filter(
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
PipelineCustomizedTemplate.id != template_id,
@ -174,7 +174,7 @@ class RagPipelineService:
"""
customized_template: PipelineCustomizedTemplate | None = (
db.session.query(PipelineCustomizedTemplate)
.filter(
.where(
PipelineCustomizedTemplate.id == template_id,
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
)
@ -185,14 +185,14 @@ class RagPipelineService:
db.session.delete(customized_template)
db.session.commit()
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None:
"""
Get draft workflow
"""
# fetch draft workflow by rag pipeline
workflow = (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
@ -203,7 +203,7 @@ class RagPipelineService:
# return draft workflow
return workflow
def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None:
"""
Get published workflow
"""
@ -214,7 +214,7 @@ class RagPipelineService:
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.id == pipeline.workflow_id,
@ -267,7 +267,7 @@ class RagPipelineService:
*,
pipeline: Pipeline,
graph: dict,
unique_hash: Optional[str],
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
@ -378,18 +378,16 @@ class RagPipelineService:
Get default block configs
"""
# return default block config
default_block_configs = []
default_block_configs: list[dict[str, Any]] = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(default_config)
default_block_configs.append(dict(default_config))
return default_block_configs
def get_default_block_config(
self, node_type: str, filters: Optional[dict] = None
) -> Optional[Mapping[str, object]]:
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
"""
Get default config of node.
:param node_type: node type
@ -495,7 +493,7 @@ class RagPipelineService:
account: Account,
datasource_type: str,
is_published: bool,
credential_id: Optional[str] = None,
credential_id: str | None = None,
) -> Generator[Mapping[str, Any], None, None]:
"""
Run published workflow datasource
@ -580,10 +578,10 @@ class RagPipelineService:
)
yield start_event.model_dump()
try:
for message in online_document_result:
for online_document_message in online_document_result:
end_time = time.time()
online_document_event = DatasourceCompletedEvent(
data=message.result, time_consuming=round(end_time - start_time, 2)
data=online_document_message.result, time_consuming=round(end_time - start_time, 2)
)
yield online_document_event.model_dump()
except Exception as e:
@ -609,10 +607,10 @@ class RagPipelineService:
completed=0,
)
yield start_event.model_dump()
for message in online_drive_result:
for online_drive_message in online_drive_result:
end_time = time.time()
online_drive_event = DatasourceCompletedEvent(
data=message.result,
data=online_drive_message.result,
time_consuming=round(end_time - start_time, 2),
total=None,
completed=None,
@ -629,19 +627,20 @@ class RagPipelineService:
)
start_time = time.time()
try:
for message in website_crawl_result:
for website_crawl_message in website_crawl_result:
end_time = time.time()
if message.result.status == "completed":
crawl_event: DatasourceCompletedEvent | DatasourceProcessingEvent
if website_crawl_message.result.status == "completed":
crawl_event = DatasourceCompletedEvent(
data=message.result.web_info_list or [],
total=message.result.total,
completed=message.result.completed,
data=website_crawl_message.result.web_info_list or [],
total=website_crawl_message.result.total,
completed=website_crawl_message.result.completed,
time_consuming=round(end_time - start_time, 2),
)
else:
crawl_event = DatasourceProcessingEvent(
total=message.result.total,
completed=message.result.completed,
total=website_crawl_message.result.total,
completed=website_crawl_message.result.completed,
)
yield crawl_event.model_dump()
except Exception as e:
@ -661,7 +660,7 @@ class RagPipelineService:
account: Account,
datasource_type: str,
is_published: bool,
credential_id: Optional[str] = None,
credential_id: str | None = None,
) -> Mapping[str, Any]:
"""
Run published workflow datasource
@ -723,12 +722,12 @@ class RagPipelineService:
)
try:
variables: dict[str, Any] = {}
for message in online_document_result:
if message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(message.message, DatasourceMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
for online_document_message in online_document_result:
if online_document_message.type == DatasourceMessage.MessageType.VARIABLE:
assert isinstance(online_document_message.message, DatasourceMessage.VariableMessage)
variable_name = online_document_message.message.variable_name
variable_value = online_document_message.message.variable_value
if online_document_message.message.stream:
if not isinstance(variable_value, str):
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
@ -793,8 +792,9 @@ class RagPipelineService:
for event in generator:
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)):
node_run_result = event.node_run_result
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
if node_run_result:
# sign output files
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
break
if not node_run_result:
@ -876,7 +876,7 @@ class RagPipelineService:
if invoke_from.value == InvokeFrom.PUBLISHED.value:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter(Document.id == document_id.value).first()
document = db.session.query(Document).where(Document.id == document_id.value).first()
if document:
document.indexing_status = "error"
document.error = error
@ -887,7 +887,7 @@ class RagPipelineService:
def update_workflow(
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
) -> Optional[Workflow]:
) -> Workflow | None:
"""
Update workflow attributes
@ -1016,7 +1016,7 @@ class RagPipelineService:
"""
limit = int(args.get("limit", 20))
base_query = db.session.query(WorkflowRun).filter(
base_query = db.session.query(WorkflowRun).where(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
or_(
@ -1026,7 +1026,7 @@ class RagPipelineService:
)
if args.get("last_id"):
last_workflow_run = base_query.filter(
last_workflow_run = base_query.where(
WorkflowRun.id == args.get("last_id"),
).first()
@ -1034,7 +1034,7 @@ class RagPipelineService:
raise ValueError("Last workflow run not exists")
workflow_runs = (
base_query.filter(
base_query.where(
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
)
.order_by(WorkflowRun.created_at.desc())
@ -1047,7 +1047,7 @@ class RagPipelineService:
has_more = False
if len(workflow_runs) == limit:
current_page_first_workflow_run = workflow_runs[-1]
rest_count = base_query.filter(
rest_count = base_query.where(
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
WorkflowRun.id != current_page_first_workflow_run.id,
).count()
@ -1057,7 +1057,7 @@ class RagPipelineService:
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
"""
Get workflow run detail
@ -1066,7 +1066,7 @@ class RagPipelineService:
"""
workflow_run = (
db.session.query(WorkflowRun)
.filter(
.where(
WorkflowRun.tenant_id == pipeline.tenant_id,
WorkflowRun.app_id == pipeline.id,
WorkflowRun.id == run_id,
@ -1113,12 +1113,12 @@ class RagPipelineService:
"""
Publish customized pipeline template
"""
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
if not pipeline.workflow_id:
raise ValueError("Pipeline workflow not found")
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError("Workflow not found")
with Session(db.engine) as session:
@ -1131,7 +1131,7 @@ class RagPipelineService:
if template_name:
template = (
db.session.query(PipelineCustomizedTemplate)
.filter(
.where(
PipelineCustomizedTemplate.name == template_name,
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
)
@ -1142,7 +1142,7 @@ class RagPipelineService:
max_position = (
db.session.query(func.max(PipelineCustomizedTemplate.position))
.filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
.scalar()
)
@ -1169,7 +1169,7 @@ class RagPipelineService:
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
return (
db.session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == Workflow.VERSION_DRAFT,
@ -1278,7 +1278,7 @@ class RagPipelineService:
# Query active recommended plugins
pipeline_recommended_plugins = (
db.session.query(PipelineRecommendedPlugin)
.filter(PipelineRecommendedPlugin.active == True)
.where(PipelineRecommendedPlugin.active == True)
.order_by(PipelineRecommendedPlugin.position.asc())
.all()
)
@ -1329,12 +1329,12 @@ class RagPipelineService:
"""
document_pipeline_excution_log = (
db.session.query(DocumentPipelineExecutionLog)
.filter(DocumentPipelineExecutionLog.document_id == document.id)
.where(DocumentPipelineExecutionLog.document_id == document.id)
.first()
)
if not document_pipeline_excution_log:
raise ValueError("Document pipeline execution log not found")
pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
# convert to app config
@ -1358,3 +1358,99 @@ class RagPipelineService:
workflow_thread_pool_id=None,
is_retry=True,
)
def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
"""
Get datasource plugins
"""
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
workflow: Workflow | None = None
if is_published:
workflow = self.get_published_workflow(pipeline=pipeline)
else:
workflow = self.get_draft_workflow(pipeline=pipeline)
if not pipeline or not workflow:
raise ValueError("Pipeline or workflow not found")
datasource_nodes = workflow.graph_dict.get("nodes", [])
datasource_plugins = []
for datasource_node in datasource_nodes:
if datasource_node.get("type") == "datasource":
datasource_node_data = datasource_node.get("data", {})
if not datasource_node_data:
continue
variables = workflow.rag_pipeline_variables
if variables:
variables_map = {item["variable"]: item for item in variables}
else:
variables_map = {}
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
user_input_variables_keys = []
user_input_variables = []
for _, value in datasource_parameters.items():
if value.get("value") and isinstance(value.get("value"), str):
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
match = re.match(pattern, value["value"])
if match:
full_path = match.group(1)
last_part = full_path.split(".")[-1]
user_input_variables_keys.append(last_part)
elif value.get("value") and isinstance(value.get("value"), list):
last_part = value.get("value")[-1]
user_input_variables_keys.append(last_part)
for key, value in variables_map.items():
if key in user_input_variables_keys:
user_input_variables.append(value)
# get credentials
datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
tenant_id=tenant_id,
provider=datasource_node_data.get("provider_name"),
plugin_id=datasource_node_data.get("plugin_id"),
)
credential_info_list: list[Any] = []
for credential in credentials:
credential_info_list.append(
{
"id": credential.get("id"),
"name": credential.get("name"),
"type": credential.get("type"),
"is_default": credential.get("is_default"),
}
)
datasource_plugins.append(
{
"node_id": datasource_node.get("id"),
"plugin_id": datasource_node_data.get("plugin_id"),
"provider_name": datasource_node_data.get("provider_name"),
"datasource_type": datasource_node_data.get("provider_type"),
"title": datasource_node_data.get("title"),
"user_input_variables": user_input_variables,
"credentials": credential_info_list,
}
)
return datasource_plugins
def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
"""
Get pipeline
"""
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
if not pipeline:
raise ValueError("Pipeline not found")
return pipeline
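get_datasource_plugins() above resolves which pipeline variables feed a datasource node by matching {{#node_id.field#}} references inside datasource_parameters and keeping only the trailing path segment as the user-input variable key. A minimal sketch of that extraction using the same pattern (the sample values are hypothetical):
import re

REFERENCE_PATTERN = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"


def extract_variable_key(value: str) -> str | None:
    match = re.match(REFERENCE_PATTERN, value)
    if not match:
        return None
    full_path = match.group(1)        # e.g. "1712345.start_page"
    return full_path.split(".")[-1]   # last segment is the user-input variable key


print(extract_variable_key("{{#1712345.start_page#}}"))  # -> "start_page"
print(extract_variable_key("plain constant"))            # -> None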

View File

@ -6,7 +6,7 @@ import uuid
from collections.abc import Mapping
from datetime import UTC, datetime
from enum import StrEnum
from typing import Optional, cast
from typing import cast
from urllib.parse import urlparse
from uuid import uuid4
@ -66,11 +66,11 @@ class ImportStatus(StrEnum):
class RagPipelineImportInfo(BaseModel):
id: str
status: ImportStatus
pipeline_id: Optional[str] = None
pipeline_id: str | None = None
current_dsl_version: str = CURRENT_DSL_VERSION
imported_dsl_version: str = ""
error: str = ""
dataset_id: Optional[str] = None
dataset_id: str | None = None
class CheckDependenciesResult(BaseModel):
@ -121,12 +121,12 @@ class RagPipelineDslService:
*,
account: Account,
import_mode: str,
yaml_content: Optional[str] = None,
yaml_url: Optional[str] = None,
pipeline_id: Optional[str] = None,
dataset: Optional[Dataset] = None,
dataset_name: Optional[str] = None,
icon_info: Optional[IconInfo] = None,
yaml_content: str | None = None,
yaml_url: str | None = None,
pipeline_id: str | None = None,
dataset: Dataset | None = None,
dataset_name: str | None = None,
icon_info: IconInfo | None = None,
) -> RagPipelineImportInfo:
"""Import an app from YAML content or URL."""
import_id = str(uuid.uuid4())
@ -318,7 +318,7 @@ class RagPipelineDslService:
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.filter(
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
@ -452,7 +452,7 @@ class RagPipelineDslService:
if knowledge_configuration.indexing_technique == "high_quality":
dataset_collection_binding = (
self._session.query(DatasetCollectionBinding)
.filter(
.where(
DatasetCollectionBinding.provider_name
== knowledge_configuration.embedding_model_provider,
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
@ -530,10 +530,10 @@ class RagPipelineDslService:
def _create_or_update_pipeline(
self,
*,
pipeline: Optional[Pipeline],
pipeline: Pipeline | None,
data: dict,
account: Account,
dependencies: Optional[list[PluginDependency]] = None,
dependencies: list[PluginDependency] | None = None,
) -> Pipeline:
"""Create a new app or update an existing one."""
if not account.current_tenant_id:
@ -599,7 +599,7 @@ class RagPipelineDslService:
)
workflow = (
self._session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",
@ -673,7 +673,7 @@ class RagPipelineDslService:
workflow = (
self._session.query(Workflow)
.filter(
.where(
Workflow.tenant_id == pipeline.tenant_id,
Workflow.app_id == pipeline.id,
Workflow.version == "draft",

View File

@ -1,7 +1,6 @@
import json
from datetime import UTC, datetime
from pathlib import Path
from typing import Optional
from uuid import uuid4
import yaml
@ -21,7 +20,7 @@ from services.plugin.plugin_service import PluginService
class RagPipelineTransformService:
def transform_dataset(self, dataset_id: str):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset not found")
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
@ -90,7 +89,7 @@ class RagPipelineTransformService:
"status": "success",
}
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
pipeline_yaml = {}
if doc_form == "text_model":
match datasource_type:
@ -152,7 +151,7 @@ class RagPipelineTransformService:
return node
def _deal_knowledge_index(
self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
@ -289,7 +288,7 @@ class RagPipelineTransformService:
jina_node_id = "1752491761974"
firecrawl_node_id = "1752565402678"
documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all()
documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
for document in documents:
data_source_info_dict = document.data_source_info_dict
@ -299,7 +298,7 @@ class RagPipelineTransformService:
document.data_source_type = "local_file"
file_id = data_source_info_dict.get("upload_file_id")
if file_id:
file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
data_source_info = json.dumps(
{

View File

@ -99,12 +99,6 @@ class ToolTransformService:
provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.declaration.identity.icon
)
else:
provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value,
provider_name=provider.name,
icon=provider.declaration.identity.icon,
)
@classmethod
def builtin_provider_to_user_provider(

View File

@ -1,6 +1,5 @@
import logging
import time
from typing import Optional
import click
from celery import shared_task
@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]):
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
"""
Clean documents when documents are deleted.
:param document_ids: document ids

View File

@ -33,7 +33,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@ -44,7 +44,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@ -54,7 +54,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# add from vector index
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@ -76,19 +76,19 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@ -100,7 +100,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@ -113,7 +113,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
try:
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@ -148,12 +148,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()

View File

@ -104,20 +104,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
with Session(db.engine, expire_on_commit=False) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

View File

@ -125,20 +125,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

View File

@ -38,7 +38,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant