mirror of https://github.com/langgenius/dify.git
Merge branch 'feat/rag-2' into fix/dependency-version
commit 18027b530a
@@ -1440,12 +1440,12 @@ def transform_datasource_credentials():
notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all()
if notion_credentials:
notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {}
- for credential in notion_credentials:
- tenant_id = credential.tenant_id
+ for notion_credential in notion_credentials:
+ tenant_id = notion_credential.tenant_id
if tenant_id not in notion_credentials_tenant_mapping:
notion_credentials_tenant_mapping[tenant_id] = []
- notion_credentials_tenant_mapping[tenant_id].append(credential)
- for tenant_id, credentials in notion_credentials_tenant_mapping.items():
+ notion_credentials_tenant_mapping[tenant_id].append(notion_credential)
+ for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items():
# check notion plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]

@@ -1454,12 +1454,12 @@ def transform_datasource_credentials():
# install notion plugin
PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier])
auth_count = 0
- for credential in credentials:
+ for notion_tenant_credential in notion_tenant_credentials:
auth_count += 1
# get credential oauth params
- access_token = credential.access_token
+ access_token = notion_tenant_credential.access_token
# notion info
- notion_info = credential.source_info
+ notion_info = notion_tenant_credential.source_info
workspace_id = notion_info.get("workspace_id")
workspace_name = notion_info.get("workspace_name")
workspace_icon = notion_info.get("workspace_icon")

@@ -1487,12 +1487,12 @@ def transform_datasource_credentials():
firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all()
if firecrawl_credentials:
firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
- for credential in firecrawl_credentials:
- tenant_id = credential.tenant_id
+ for firecrawl_credential in firecrawl_credentials:
+ tenant_id = firecrawl_credential.tenant_id
if tenant_id not in firecrawl_credentials_tenant_mapping:
firecrawl_credentials_tenant_mapping[tenant_id] = []
- firecrawl_credentials_tenant_mapping[tenant_id].append(credential)
- for tenant_id, credentials in firecrawl_credentials_tenant_mapping.items():
+ firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential)
+ for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items():
# check firecrawl plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]

@@ -1502,10 +1502,10 @@ def transform_datasource_credentials():
PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier])

auth_count = 0
- for credential in credentials:
+ for firecrawl_tenant_credential in firecrawl_tenant_credentials:
auth_count += 1
# get credential api key
- credentials_json = json.loads(credential.credentials)
+ credentials_json = json.loads(firecrawl_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
base_url = credentials_json.get("config", {}).get("base_url")
new_credentials = {

@@ -1530,12 +1530,12 @@ def transform_datasource_credentials():
jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all()
if jina_credentials:
jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {}
- for credential in jina_credentials:
- tenant_id = credential.tenant_id
+ for jina_credential in jina_credentials:
+ tenant_id = jina_credential.tenant_id
if tenant_id not in jina_credentials_tenant_mapping:
jina_credentials_tenant_mapping[tenant_id] = []
- jina_credentials_tenant_mapping[tenant_id].append(credential)
- for tenant_id, credentials in jina_credentials_tenant_mapping.items():
+ jina_credentials_tenant_mapping[tenant_id].append(jina_credential)
+ for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items():
# check jina plugin is installed
installed_plugins = installer_manager.list_plugins(tenant_id)
installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins]

@@ -1546,10 +1546,10 @@ def transform_datasource_credentials():
PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier])

auth_count = 0
- for credential in credentials:
+ for jina_tenant_credential in jina_tenant_credentials:
auth_count += 1
# get credential api key
- credentials_json = json.loads(credential.credentials)
+ credentials_json = json.loads(jina_tenant_credential.credentials)
api_key = credentials_json.get("config", {}).get("api_key")
new_credentials = {
"integration_secret": api_key,
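The hunks above rename the per-provider loop variables (credential becomes notion_credential, firecrawl_credential, jina_credential) so the nested migration loops no longer reuse the same name; the underlying logic is a plain bucket-by-tenant pass followed by a per-tenant plugin install. A minimal, self-contained sketch of that grouping pattern; the names below are illustrative and not code from this commit:

from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Credential:  # stand-in for DataSourceOauthBinding / DataSourceApiKeyAuthBinding
    tenant_id: str
    access_token: str


def group_by_tenant(credentials: list[Credential]) -> dict[str, list[Credential]]:
    """Bucket credentials per tenant so each tenant can be migrated independently."""
    mapping: defaultdict[str, list[Credential]] = defaultdict(list)
    for credential in credentials:
        mapping[credential.tenant_id].append(credential)
    return dict(mapping)


if __name__ == "__main__":
    rows = [Credential("t1", "a"), Credential("t1", "b"), Credential("t2", "c")]
    print(group_by_tenant(rows))  # {'t1': [Credential(...), Credential(...)], 't2': [Credential(...)]}

The migration itself does the same thing inline once per provider, then checks whether the matching plugin is installed for the tenant before converting each credential.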
@@ -1,3 +1,3 @@
from .app_config import DifyConfig

- dify_config = DifyConfig() # pyright: ignore[reportCallIssue]
+ dify_config = DifyConfig() # type: ignore
@@ -512,11 +512,11 @@ class WorkflowVariableTruncationConfig(BaseSettings):
description="Maximum size for variable to trigger final truncation.",
)
WORKFLOW_VARIABLE_TRUNCATION_STRING_LENGTH: PositiveInt = Field(
- 50000,
+ 100000,
description="maximum length for string to trigger tuncation, measure in number of characters",
)
WORKFLOW_VARIABLE_TRUNCATION_ARRAY_LENGTH: PositiveInt = Field(
- 100,
+ 1000,
description="maximum length for array to trigger truncation.",
)
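This hunk only raises two truncation thresholds (strings: 50,000 to 100,000 characters; arrays: 100 to 1,000 items). For context, a minimal sketch of how such a pydantic-settings field takes its default from the first positional argument to Field() and can still be overridden through the environment; the class and field names below are illustrative:

import os

from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings


class TruncationConfig(BaseSettings):
    # The first positional argument to Field() is the default value.
    STRING_TRUNCATION_LENGTH: PositiveInt = Field(
        100000,
        description="maximum string length before truncation, in characters",
    )


if __name__ == "__main__":
    print(TruncationConfig().STRING_TRUNCATION_LENGTH)  # 100000 (declared default)
    os.environ["STRING_TRUNCATION_LENGTH"] = "5000"
    print(TruncationConfig().STRING_TRUNCATION_LENGTH)  # 5000 (environment override)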
@@ -104,7 +104,7 @@ class CustomizedPipelineTemplateApi(Resource):
def post(self, template_id: str):
with Session(db.engine) as session:
template = (
- session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
+ session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
)
if not template:
raise ValueError("Customized pipeline template not found.")
@@ -1,6 +1,5 @@
from collections.abc import Callable
from functools import wraps
- from typing import Optional

from controllers.console.datasets.error import PipelineNotFoundError
from extensions.ext_database import db

@@ -10,7 +9,7 @@ from models.dataset import Pipeline


def get_rag_pipeline(
- view: Optional[Callable] = None,
+ view: Callable | None = None,
):
def decorator(view_func):
@wraps(view_func)

@@ -28,7 +27,7 @@ def get_rag_pipeline(

pipeline = (
db.session.query(Pipeline)
- .filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
+ .where(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
.first()
)
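Several hunks in this commit, including the one above, swap Query.filter(...) for Query.where(...). On SQLAlchemy's legacy Query object, .where() has been a synonym of .filter() since 1.4, so the change is behavior-preserving and matches the 2.0-style select() API. A small runnable sketch of the equivalence; the model and engine are illustrative, not the repository's:

from sqlalchemy import create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Pipeline(Base):
    __tablename__ = "pipelines"
    id: Mapped[str] = mapped_column(primary_key=True)
    tenant_id: Mapped[str] = mapped_column()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Pipeline(id="p1", tenant_id="t1"))
    session.commit()
    # Legacy Query API: .where() is an alias of .filter().
    legacy = session.query(Pipeline).where(Pipeline.id == "p1").first()
    # 2.0-style equivalent built on select().
    modern = session.scalars(select(Pipeline).where(Pipeline.id == "p1")).first()
    assert legacy is modern  # same identity-mapped instance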
@@ -421,11 +421,10 @@ class PluginUploadFileRequestApi(Resource):
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
# generate signed url
url = get_signed_file_url_for_plugin(
- payload.filename,
- payload.mimetype,
- tenant_model.id,
- user_model.id,
- user_model.session_id if isinstance(user_model, EndUser) else None,
+ filename=payload.filename,
+ mimetype=payload.mimetype,
+ tenant_id=tenant_model.id,
+ user_id=user_model.id,
)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
@@ -32,11 +32,20 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser:
user_model = (
session.query(EndUser)
.where(
- EndUser.session_id == user_id,
+ EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
+ user_model = (
+ session.query(EndUser)
+ .where(
+ EndUser.session_id == user_id,
+ EndUser.tenant_id == tenant_id,
+ )
+ .first()
+ )
+ if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
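The hunk above changes the plugin-side end-user lookup to try the primary key first, fall back to the session id, and only then create a placeholder EndUser. A condensed sketch of that resolve-or-create chain, reduced to plain callables; the helper names are illustrative, not the repository's API:

from collections.abc import Callable, Iterable
from typing import TypeVar

T = TypeVar("T")


def first_match(lookups: Iterable[Callable[[], T | None]], create: Callable[[], T]) -> T:
    """Try each lookup in order; create a new record only if every lookup misses."""
    for lookup in lookups:
        found = lookup()
        if found is not None:
            return found
    return create()


users_by_id = {"u1": "end-user-1"}
users_by_session = {"sess-9": "end-user-9"}
user = first_match(
    [lambda: users_by_id.get("sess-9"), lambda: users_by_session.get("sess-9")],
    create=lambda: "newly-created-end-user",
)
print(user)  # end-user-9 (resolved by the session-id fallback)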
@@ -124,6 +124,9 @@ class DocumentAddByTextApi(DatasetApiResource):
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)

+ if not current_user:
+ raise ValueError("current_user is required")

upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)

@@ -204,6 +207,8 @@ class DocumentUpdateByTextApi(DatasetApiResource):
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
+ if not current_user:
+ raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_text(
text=str(text), text_name=str(name), user_id=current_user.id, tenant_id=tenant_id
)

@@ -308,6 +313,8 @@ class DocumentAddByFileApi(DatasetApiResource):

if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")
+ if not current_user:
+ raise ValueError("current_user is required")
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),

@@ -396,8 +403,12 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError

+ if not current_user:
+ raise ValueError("current_user is required")

if not isinstance(current_user, EndUser):
raise ValueError("Invalid user account")

try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,

@@ -577,7 +588,7 @@ class DocumentApi(DatasetApiResource):
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
elif metadata == "without":
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
- document_process_rules = document.dataset_process_rule.to_dict()
+ document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,

@@ -610,7 +621,7 @@ class DocumentApi(DatasetApiResource):
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
- document_process_rules = document.dataset_process_rule.to_dict()
+ document_process_rules = document.dataset_process_rule.to_dict() if document.dataset_process_rule else {}
data_source_info = document.data_source_detail_dict
response = {
"id": document.id,
@@ -47,3 +47,9 @@ class DatasetInUseError(BaseHTTPException):
error_code = "dataset_in_use"
description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it."
code = 409
+
+
+ class PipelineRunError(BaseHTTPException):
+ error_code = "pipeline_run_error"
+ description = "An error occurred while running the pipeline."
+ code = 500
@@ -216,6 +216,9 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError

+ if not current_user:
+ raise ValueError("Invalid user account")

try:
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
@@ -1,6 +1,6 @@
from collections.abc import Sequence
from enum import StrEnum, auto
- from typing import Any, Literal, Optional
+ from typing import Any, Literal

from pydantic import BaseModel, Field, field_validator


@@ -114,9 +114,9 @@ class VariableEntity(BaseModel):
hide: bool = False
max_length: int | None = None
options: Sequence[str] = Field(default_factory=list)
- allowed_file_types: Optional[Sequence[FileType]] = Field(default_factory=list)
- allowed_file_extensions: Optional[Sequence[str]] = Field(default_factory=list)
- allowed_file_upload_methods: Optional[Sequence[FileTransferMethod]] = Field(default_factory=list)
+ allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
+ allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
+ allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)

@field_validator("description", mode="before")
@classmethod

@@ -134,8 +134,8 @@ class RagPipelineVariableEntity(VariableEntity):
Rag Pipeline Variable Entity.
"""

- tooltips: Optional[str] = None
- placeholder: Optional[str] = None
+ tooltips: str | None = None
+ placeholder: str | None = None
belong_to_node_id: str


@@ -298,7 +298,7 @@ class AppConfig(BaseModel):
tenant_id: str
app_id: str
app_mode: AppMode
- additional_features: Optional[AppAdditionalFeatures] = None
+ additional_features: AppAdditionalFeatures | None = None
variables: list[VariableEntity] = []
sensitive_word_avoidance: SensitiveWordAvoidanceEntity | None = None
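Like most of the remaining hunks in this commit, the entity changes above replace typing.Optional[X] with the PEP 604 spelling X | None, which requires Python 3.10+ and which Pydantic v2 treats identically. A minimal sketch with illustrative field names:

from collections.abc import Sequence

from pydantic import BaseModel, Field


class VariableSketch(BaseModel):
    # Equivalent to Optional[Sequence[str]]; the | syntax needs Python 3.10+.
    allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
    max_length: int | None = None


print(VariableSketch().model_dump())
# {'allowed_file_extensions': [], 'max_length': None}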
@@ -45,9 +45,9 @@ class BaseAppGenerator:
mapping=v,
tenant_id=tenant_id,
config=FileUploadConfig(
- allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType]
- allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType]
- allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType]
+ allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+ allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+ allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
strict_type_validation=strict_type_validation,
)

@@ -60,9 +60,9 @@ class BaseAppGenerator:
mappings=v,
tenant_id=tenant_id,
config=FileUploadConfig(
- allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType]
- allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType]
- allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType]
+ allowed_file_types=entity_dictionary[k].allowed_file_types or [],
+ allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [],
+ allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [],
),
)
for k, v in user_inputs.items()
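Replacing the pyright suppressions with `... or []` normalizes a None value to an empty list before it reaches FileUploadConfig, which is why the type-checker comments can be dropped. A one-line sketch of the idiom; the function name is illustrative:

def normalize(allowed_file_types: list[str] | None) -> list[str]:
    # `or []` maps None (and an empty list) to [], which is what the config expects.
    return allowed_file_types or []


print(normalize(None), normalize(["image"]))  # [] ['image']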
@@ -23,7 +23,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
:param blocking_response: blocking response
:return:
"""
- return dict(blocking_response.to_dict())
+ return dict(blocking_response.model_dump())

@classmethod
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]

@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
- response_chunk.update(sub_stream_response.to_dict())
+ response_chunk.update(sub_stream_response.model_dump())
yield response_chunk

@classmethod

@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
- response_chunk.update(sub_stream_response.to_dict())
+ response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
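The converter hunks drop the custom to_dict() helpers in favour of Pydantic v2's built-in model_dump(), which serializes a model to plain Python types. A minimal sketch; the model below is illustrative, not the real WorkflowAppBlockingResponse:

from pydantic import BaseModel


class BlockingResponseSketch(BaseModel):
    task_id: str
    data: dict[str, str]


resp = BlockingResponseSketch(task_id="t-1", data={"status": "succeeded"})
print(dict(resp.model_dump()))  # {'task_id': 't-1', 'data': {'status': 'succeeded'}}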
@ -7,7 +7,7 @@ import threading
|
|||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Literal, Optional, Union, cast, overload
|
||||
from typing import Any, Literal, Union, cast, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -69,7 +69,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
|
|
@ -84,7 +84,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_thread_pool_id: Optional[str],
|
||||
workflow_thread_pool_id: str | None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||
|
||||
|
|
@ -113,7 +113,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
is_retry: bool = False,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||
# Add null check for dataset
|
||||
|
|
@ -314,7 +314,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
streaming: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
|
|
@ -331,7 +331,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
"""
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
# init queue manager
|
||||
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == workflow_id).first()
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
queue_manager = PipelineQueueManager(
|
||||
|
|
@ -568,7 +568,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
queue_manager: AppQueueManager,
|
||||
context: contextvars.Context,
|
||||
variable_loader: VariableLoader,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
|
|
@ -744,7 +744,7 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
Format datasource info list.
|
||||
"""
|
||||
if datasource_type == "online_drive":
|
||||
all_files = []
|
||||
all_files: list[Mapping[str, Any]] = []
|
||||
datasource_node_data = None
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
for datasource_node in datasource_nodes:
|
||||
|
|
@ -801,11 +801,11 @@ class PipelineGenerator(BaseAppGenerator):
|
|||
self,
|
||||
datasource_runtime: OnlineDriveDatasourcePlugin,
|
||||
prefix: str,
|
||||
bucket: Optional[str],
|
||||
bucket: str | None,
|
||||
user_id: str,
|
||||
all_files: list,
|
||||
datasource_info: Mapping[str, Any],
|
||||
next_page_parameters: Optional[dict] = None,
|
||||
next_page_parameters: dict | None = None,
|
||||
):
|
||||
"""
|
||||
Get files in a folder.
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||
|
|
@ -40,7 +40,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
variable_loader: VariableLoader,
|
||||
workflow: Workflow,
|
||||
system_user_id: str,
|
||||
workflow_thread_pool_id: Optional[str] = None,
|
||||
workflow_thread_pool_id: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param application_generate_entity: application generate entity
|
||||
|
|
@ -69,13 +69,13 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
|
||||
user_id = None
|
||||
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
|
||||
if end_user:
|
||||
user_id = end_user.session_id
|
||||
else:
|
||||
user_id = self.application_generate_entity.user_id
|
||||
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == app_config.app_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
|
|
@ -188,16 +188,14 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
)
|
||||
self._handle_event(workflow_entry, event)
|
||||
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
|
||||
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
|
||||
)
|
||||
.where(Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
|
|
@ -205,7 +203,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
return workflow
|
||||
|
||||
def _init_rag_pipeline_graph(
|
||||
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: Optional[str] = None
|
||||
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
|
||||
) -> Graph:
|
||||
"""
|
||||
Init pipeline graph
|
||||
|
|
@ -272,7 +270,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
|
|||
if document_id and dataset_id:
|
||||
document = (
|
||||
db.session.query(Document)
|
||||
.filter(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
||||
.first()
|
||||
)
|
||||
if document:
|
||||
|
|
|
|||
|
|
@ -242,7 +242,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
|||
node_id: str
|
||||
inputs: dict
|
||||
|
||||
single_loop_run: Optional[SingleLoopRunEntity] = None
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
|
||||
class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
||||
|
|
@ -256,9 +256,9 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
|
|||
datasource_info: Mapping[str, Any]
|
||||
dataset_id: str
|
||||
batch: str
|
||||
document_id: Optional[str] = None
|
||||
original_document_id: Optional[str] = None
|
||||
start_node_id: Optional[str] = None
|
||||
document_id: str | None = None
|
||||
original_document_id: str | None = None
|
||||
start_node_id: str | None = None
|
||||
|
||||
|
||||
# Import TraceQueueManager at runtime to resolve forward references
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping, Sequence
|
||||
from enum import StrEnum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
|
@ -252,8 +252,8 @@ class NodeStartStreamResponse(StreamResponse):
|
|||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: dict[str, object] = Field(default_factory=dict)
|
||||
|
|
@ -310,12 +310,12 @@ class NodeFinishStreamResponse(StreamResponse):
|
|||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs_truncated: bool = False
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
process_data_truncated: bool = False
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = True
|
||||
status: str
|
||||
error: str | None = None
|
||||
|
|
@ -382,12 +382,12 @@ class NodeRetryStreamResponse(StreamResponse):
|
|||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
predecessor_node_id: Optional[str] = None
|
||||
inputs: Optional[Mapping[str, Any]] = None
|
||||
predecessor_node_id: str | None = None
|
||||
inputs: Mapping[str, Any] | None = None
|
||||
inputs_truncated: bool = False
|
||||
process_data: Optional[Mapping[str, Any]] = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
process_data_truncated: bool = False
|
||||
outputs: Optional[Mapping[str, Any]] = None
|
||||
outputs: Mapping[str, Any] | None = None
|
||||
outputs_truncated: bool = False
|
||||
status: str
|
||||
error: str | None = None
|
||||
|
|
@ -503,11 +503,11 @@ class IterationNodeCompletedStreamResponse(StreamResponse):
|
|||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Optional[Mapping] = None
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: Optional[dict] = None
|
||||
inputs: Optional[Mapping] = None
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
|
|
@ -541,8 +541,8 @@ class LoopNodeStartStreamResponse(StreamResponse):
|
|||
metadata: Mapping = {}
|
||||
inputs: Mapping = {}
|
||||
inputs_truncated: bool = False
|
||||
parallel_id: Optional[str] = None
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
parallel_id: str | None = None
|
||||
parallel_start_node_id: str | None = None
|
||||
|
||||
event: StreamEvent = StreamEvent.LOOP_STARTED
|
||||
workflow_run_id: str
|
||||
|
|
@ -590,11 +590,11 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
|
|||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
outputs: Optional[Mapping] = None
|
||||
outputs: Mapping | None = None
|
||||
outputs_truncated: bool = False
|
||||
created_at: int
|
||||
extras: Optional[dict] = None
|
||||
inputs: Optional[Mapping] = None
|
||||
extras: dict | None = None
|
||||
inputs: Mapping | None = None
|
||||
inputs_truncated: bool = False
|
||||
status: WorkflowNodeExecutionStatus
|
||||
error: str | None = None
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ class DatasourceRuntime(BaseModel):
|
|||
"""
|
||||
|
||||
tenant_id: str
|
||||
datasource_id: Optional[str] = None
|
||||
datasource_id: str | None = None
|
||||
invoke_from: Optional["InvokeFrom"] = None
|
||||
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
|
||||
datasource_invoke_from: DatasourceInvokeFrom | None = None
|
||||
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import os
|
|||
import time
|
||||
from datetime import datetime
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional, Union
|
||||
from typing import Union
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
|
@ -62,10 +62,10 @@ class DatasourceFileManager:
|
|||
*,
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str],
|
||||
conversation_id: str | None,
|
||||
file_binary: bytes,
|
||||
mimetype: str,
|
||||
filename: Optional[str] = None,
|
||||
filename: str | None = None,
|
||||
) -> UploadFile:
|
||||
extension = guess_extension(mimetype) or ".bin"
|
||||
unique_name = uuid4().hex
|
||||
|
|
@ -106,7 +106,7 @@ class DatasourceFileManager:
|
|||
user_id: str,
|
||||
tenant_id: str,
|
||||
file_url: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> ToolFile:
|
||||
# try to download image
|
||||
try:
|
||||
|
|
@ -152,13 +152,7 @@ class DatasourceFileManager:
|
|||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == id).first()
|
||||
|
||||
if not upload_file:
|
||||
return None
|
||||
|
|
@ -176,13 +170,7 @@ class DatasourceFileManager:
|
|||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
message_file: MessageFile | None = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(
|
||||
MessageFile.id == id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
message_file: MessageFile | None = db.session.query(MessageFile).where(MessageFile.id == id).first()
|
||||
|
||||
# Check if message_file is not None
|
||||
if message_file is not None:
|
||||
|
|
@ -196,13 +184,7 @@ class DatasourceFileManager:
|
|||
else:
|
||||
tool_file_id = None
|
||||
|
||||
tool_file: ToolFile | None = (
|
||||
db.session.query(ToolFile)
|
||||
.filter(
|
||||
ToolFile.id == tool_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
tool_file: ToolFile | None = db.session.query(ToolFile).where(ToolFile.id == tool_file_id).first()
|
||||
|
||||
if not tool_file:
|
||||
return None
|
||||
|
|
@ -220,13 +202,7 @@ class DatasourceFileManager:
|
|||
|
||||
:return: the binary of the file, mime type
|
||||
"""
|
||||
upload_file: UploadFile | None = (
|
||||
db.session.query(UploadFile)
|
||||
.filter(
|
||||
UploadFile.id == upload_file_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
upload_file: UploadFile | None = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
return None, None
|
||||
|
|
|
|||
|
|
@@ -46,7 +46,7 @@ class DatasourceManager:
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
if not provider_entity:
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")

+ controller: DatasourcePluginProviderController | None = None
match datasource_type:
case DatasourceProviderType.ONLINE_DOCUMENT:
controller = OnlineDocumentDatasourcePluginProviderController(

@@ -79,7 +79,11 @@ class DatasourceManager:
case _:
raise ValueError(f"Unsupported datasource type: {datasource_type}")

- datasource_plugin_providers[provider_id] = controller
+ if controller:
+ datasource_plugin_providers[provider_id] = controller
+
+ if controller is None:
+ raise DatasourceProviderNotFoundError(f"Datasource provider {provider_id} not found.")

return controller
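The second hunk above stops registering the controller unconditionally and raises when the match statement produced no controller, so an unsupported datasource type can no longer cache a null provider. A reduced sketch of that guard around a match-based factory; string stand-ins replace the real controller classes:

class ProviderNotFoundError(Exception):
    pass


def build_controller(datasource_type: str) -> str:
    controller: str | None = None
    match datasource_type:
        case "online_document":
            controller = "OnlineDocumentController"
        case "website_crawl":
            controller = "WebsiteCrawlController"
        case _:
            raise ValueError(f"Unsupported datasource type: {datasource_type}")
    # Defensive check mirroring the hunk; unreachable here because the
    # wildcard case raises, but it protects future branches that forget
    # to assign a controller.
    if controller is None:
        raise ProviderNotFoundError("Datasource provider not found.")
    return controller


print(build_controller("online_document"))  # OnlineDocumentController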
@ -12,9 +12,9 @@ class DatasourceApiEntity(BaseModel):
|
|||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[DatasourceParameter]] = None
|
||||
parameters: list[DatasourceParameter] | None = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
output_schema: dict | None = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
|
@ -28,12 +28,12 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: str
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
masked_credentials: dict | None = None
|
||||
original_credentials: dict | None = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
|
||||
plugin_id: str | None = Field(default="", description="The plugin id of the datasource")
|
||||
plugin_unique_identifier: str | None = Field(default="", description="The unique identifier of the datasource")
|
||||
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
|
|
@ -62,7 +62,7 @@ class DatasourceProviderApiEntity(BaseModel):
|
|||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"type": self.type,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
|
|
@ -9,9 +7,9 @@ class I18nObject(BaseModel):
|
|||
"""
|
||||
|
||||
en_US: str
|
||||
zh_Hans: Optional[str] = Field(default=None)
|
||||
pt_BR: Optional[str] = Field(default=None)
|
||||
ja_JP: Optional[str] = Field(default=None)
|
||||
zh_Hans: str | None = Field(default=None)
|
||||
pt_BR: str | None = Field(default=None)
|
||||
ja_JP: str | None = Field(default=None)
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import enum
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||
from yarl import URL
|
||||
|
|
@ -80,7 +80,7 @@ class DatasourceParameter(PluginParameter):
|
|||
name: str,
|
||||
typ: DatasourceParameterType,
|
||||
required: bool,
|
||||
options: Optional[list[str]] = None,
|
||||
options: list[str] | None = None,
|
||||
) -> "DatasourceParameter":
|
||||
"""
|
||||
get a simple datasource parameter
|
||||
|
|
@ -120,14 +120,14 @@ class DatasourceIdentity(BaseModel):
|
|||
name: str = Field(..., description="The name of the datasource")
|
||||
label: I18nObject = Field(..., description="The label of the datasource")
|
||||
provider: str = Field(..., description="The provider of the datasource")
|
||||
icon: Optional[str] = None
|
||||
icon: str | None = None
|
||||
|
||||
|
||||
class DatasourceEntity(BaseModel):
|
||||
identity: DatasourceIdentity
|
||||
parameters: list[DatasourceParameter] = Field(default_factory=list)
|
||||
description: I18nObject = Field(..., description="The label of the datasource")
|
||||
output_schema: Optional[dict] = None
|
||||
output_schema: dict | None = None
|
||||
|
||||
@field_validator("parameters", mode="before")
|
||||
@classmethod
|
||||
|
|
@ -141,7 +141,7 @@ class DatasourceProviderIdentity(BaseModel):
|
|||
description: I18nObject = Field(..., description="The description of the tool")
|
||||
icon: str = Field(..., description="The icon of the tool")
|
||||
label: I18nObject = Field(..., description="The label of the tool")
|
||||
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||
tags: list[ToolLabelEnum] | None = Field(
|
||||
default=[],
|
||||
description="The tags of the tool",
|
||||
)
|
||||
|
|
@ -169,7 +169,7 @@ class DatasourceProviderEntity(BaseModel):
|
|||
|
||||
identity: DatasourceProviderIdentity
|
||||
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||
oauth_schema: Optional[OAuthSchema] = None
|
||||
oauth_schema: OAuthSchema | None = None
|
||||
provider_type: DatasourceProviderType
|
||||
|
||||
|
||||
|
|
@ -183,8 +183,8 @@ class DatasourceInvokeMeta(BaseModel):
|
|||
"""
|
||||
|
||||
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||
error: Optional[str] = None
|
||||
tool_config: Optional[dict] = None
|
||||
error: str | None = None
|
||||
tool_config: dict | None = None
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "DatasourceInvokeMeta":
|
||||
|
|
@ -233,10 +233,10 @@ class OnlineDocumentPage(BaseModel):
|
|||
|
||||
page_id: str = Field(..., description="The page id")
|
||||
page_name: str = Field(..., description="The page title")
|
||||
page_icon: Optional[dict] = Field(None, description="The page icon")
|
||||
page_icon: dict | None = Field(None, description="The page icon")
|
||||
type: str = Field(..., description="The type of the page")
|
||||
last_edited_time: str = Field(..., description="The last edited time")
|
||||
parent_id: Optional[str] = Field(None, description="The parent page id")
|
||||
parent_id: str | None = Field(None, description="The parent page id")
|
||||
|
||||
|
||||
class OnlineDocumentInfo(BaseModel):
|
||||
|
|
@ -244,9 +244,9 @@ class OnlineDocumentInfo(BaseModel):
|
|||
Online document info
|
||||
"""
|
||||
|
||||
workspace_id: Optional[str] = Field(None, description="The workspace id")
|
||||
workspace_name: Optional[str] = Field(None, description="The workspace name")
|
||||
workspace_icon: Optional[str] = Field(None, description="The workspace icon")
|
||||
workspace_id: str | None = Field(None, description="The workspace id")
|
||||
workspace_name: str | None = Field(None, description="The workspace name")
|
||||
workspace_icon: str | None = Field(None, description="The workspace icon")
|
||||
total: int = Field(..., description="The total number of documents")
|
||||
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||
|
||||
|
|
@ -307,10 +307,10 @@ class WebSiteInfo(BaseModel):
|
|||
Website info
|
||||
"""
|
||||
|
||||
status: Optional[str] = Field(..., description="crawl job status")
|
||||
web_info_list: Optional[list[WebSiteInfoDetail]] = []
|
||||
total: Optional[int] = Field(default=0, description="The total number of websites")
|
||||
completed: Optional[int] = Field(default=0, description="The number of completed websites")
|
||||
status: str | None = Field(..., description="crawl job status")
|
||||
web_info_list: list[WebSiteInfoDetail] | None = []
|
||||
total: int | None = Field(default=0, description="The total number of websites")
|
||||
completed: int | None = Field(default=0, description="The number of completed websites")
|
||||
|
||||
|
||||
class WebsiteCrawlMessage(BaseModel):
|
||||
|
|
@ -346,10 +346,10 @@ class OnlineDriveFileBucket(BaseModel):
|
|||
Online drive file bucket
|
||||
"""
|
||||
|
||||
bucket: Optional[str] = Field(None, description="The file bucket")
|
||||
bucket: str | None = Field(None, description="The file bucket")
|
||||
files: list[OnlineDriveFile] = Field(..., description="The file list")
|
||||
is_truncated: bool = Field(False, description="Whether the result is truncated")
|
||||
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
|
||||
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||
|
|
@ -357,10 +357,10 @@ class OnlineDriveBrowseFilesRequest(BaseModel):
|
|||
Get online drive file list request
|
||||
"""
|
||||
|
||||
bucket: Optional[str] = Field(None, description="The file bucket")
|
||||
bucket: str | None = Field(None, description="The file bucket")
|
||||
prefix: str = Field(..., description="The parent folder ID")
|
||||
max_keys: int = Field(20, description="Page size for pagination")
|
||||
next_page_parameters: Optional[dict] = Field(None, description="Parameters for fetching the next page")
|
||||
next_page_parameters: dict | None = Field(None, description="Parameters for fetching the next page")
|
||||
|
||||
|
||||
class OnlineDriveBrowseFilesResponse(BaseModel):
|
||||
|
|
@ -377,4 +377,4 @@ class OnlineDriveDownloadFileRequest(BaseModel):
|
|||
"""
|
||||
|
||||
id: str = Field(..., description="The id of the file")
|
||||
bucket: Optional[str] = Field(None, description="The name of the bucket")
|
||||
bucket: str | None = Field(None, description="The name of the bucket")
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import logging
|
||||
from collections.abc import Generator
|
||||
from mimetypes import guess_extension, guess_type
|
||||
from typing import Optional
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from models.tools import ToolFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -17,7 +17,7 @@ class DatasourceFileMessageTransformer:
|
|||
messages: Generator[DatasourceMessage, None, None],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
conversation_id: str | None = None,
|
||||
) -> Generator[DatasourceMessage, None, None]:
|
||||
"""
|
||||
Transform datasource message and handle file download
|
||||
|
|
@ -32,20 +32,20 @@ class DatasourceFileMessageTransformer:
|
|||
try:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_url(
|
||||
tool_file: ToolFile | None = tool_file_manager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
file_url=message.message.text,
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if tool_file:
|
||||
url = f"/files/datasources/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}"
|
||||
|
||||
url = f"/files/datasources/{file.id}{guess_extension(file.mimetype) or '.png'}"
|
||||
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
except Exception as e:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
|
|
@ -72,7 +72,7 @@ class DatasourceFileMessageTransformer:
|
|||
# FIXME: should do a type check here.
|
||||
assert isinstance(message.message.blob, bytes)
|
||||
tool_file_manager = ToolFileManager()
|
||||
file = tool_file_manager.create_file_by_raw(
|
||||
blob_tool_file: ToolFile | None = tool_file_manager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
|
|
@ -80,25 +80,27 @@ class DatasourceFileMessageTransformer:
|
|||
mimetype=mimetype,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
if blob_tool_file:
|
||||
url = cls.get_datasource_file_url(
|
||||
datasource_file_id=blob_tool_file.id, extension=guess_extension(blob_tool_file.mimetype)
|
||||
)
|
||||
|
||||
# check if file is image
|
||||
if "image" in mimetype:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
else:
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||
message=DatasourceMessage.TextMessage(text=url),
|
||||
meta=meta.copy() if meta is not None else {},
|
||||
)
|
||||
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||
meta = message.meta or {}
|
||||
file = meta.get("file", None)
|
||||
file: File | None = meta.get("file")
|
||||
if isinstance(file, File):
|
||||
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
assert file.related_id is not None
|
||||
|
|
@ -121,5 +123,5 @@ class DatasourceFileMessageTransformer:
|
|||
yield message
|
||||
|
||||
@classmethod
|
||||
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
|
||||
def get_datasource_file_url(cls, datasource_file_id: str, extension: str | None) -> str:
|
||||
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import uuid
|
|||
from json import dumps as json_dumps
|
||||
from json import loads as json_loads
|
||||
from json.decoder import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from flask import request
|
||||
from requests import get
|
||||
|
|
@ -169,9 +168,9 @@ class ApiBasedToolSchemaParser:
|
|||
return bundles
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType | None:
|
||||
parameter = parameter or {}
|
||||
typ: Optional[str] = None
|
||||
typ: str | None = None
|
||||
if parameter.get("format") == "binary":
|
||||
return ToolParameter.ToolParameterType.FILE
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
@ -16,7 +14,7 @@ class QAPreviewDetail(BaseModel):
|
|||
class IndexingEstimate(BaseModel):
|
||||
total_segments: int
|
||||
preview: list[PreviewDetail]
|
||||
qa_preview: Optional[list[QAPreviewDetail]] = None
|
||||
qa_preview: list[QAPreviewDetail] | None = None
|
||||
|
||||
|
||||
class PipelineDataset(BaseModel):
|
||||
|
|
@ -30,10 +28,10 @@ class PipelineDocument(BaseModel):
|
|||
id: str
|
||||
position: int
|
||||
data_source_type: str
|
||||
data_source_info: Optional[dict] = None
|
||||
data_source_info: dict | None = None
|
||||
name: str
|
||||
indexing_status: str
|
||||
error: Optional[str] = None
|
||||
error: str | None = None
|
||||
enabled: bool
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -25,9 +25,9 @@ def get_signed_file_url(upload_file_id: str, as_attachment=False) -> str:
return f"{url}?{query_string}"


- def get_signed_file_url_for_plugin(
- filename: str, mimetype: str, tenant_id: str, user_id: str, session_id: str | None
- ) -> str:
+ def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
# Plugin access should use internal URL for Docker network communication
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
url = f"{base_url}/files/upload/for-plugin"

@@ -37,9 +35,7 @@ def get_signed_file_url_for_plugin(
msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()

- url_user_id = session_id or user_id
- return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={url_user_id}&tenant_id={tenant_id}"
+ return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"


def verify_plugin_file_signature(
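The change above removes the session_id parameter, so the signed plugin upload URL is always issued for the account's user_id. A self-contained sketch of the HMAC-SHA256 signing scheme the helper uses, following the message format visible in the hunk; the secret, base URL, and function name here are stand-ins, not the repository's configuration:

import base64
import hashlib
import hmac
import os
import time


def sign_upload_url(filename: str, mimetype: str, tenant_id: str, user_id: str,
                    base_url: str = "http://files.internal", secret: bytes = b"example-secret") -> str:
    timestamp = str(int(time.time()))
    nonce = os.urandom(8).hex()
    # Same pipe-delimited message layout as the hunk above.
    msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
    sign = hmac.new(secret, msg.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()
    return (f"{base_url}/files/upload/for-plugin?timestamp={timestamp}&nonce={nonce}"
            f"&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}")


print(sign_upload_url("a.png", "image/png", "tenant-1", "user-1"))

The server-side verify step would rebuild the same message from the query parameters and compare signatures with hmac.compare_digest.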
@ -119,6 +119,7 @@ class File(BaseModel):
|
|||
assert self.related_id is not None
|
||||
assert self.extension is not None
|
||||
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
|
||||
return None
|
||||
|
||||
def to_plugin_parameter(self) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import datetime
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
|
@ -67,10 +67,10 @@ class PluginCategory(StrEnum):
|
|||
|
||||
class PluginDeclaration(BaseModel):
|
||||
class Plugins(BaseModel):
|
||||
tools: Optional[list[str]] = Field(default_factory=list[str])
|
||||
models: Optional[list[str]] = Field(default_factory=list[str])
|
||||
endpoints: Optional[list[str]] = Field(default_factory=list[str])
|
||||
datasources: Optional[list[str]] = Field(default_factory=list[str])
|
||||
tools: list[str] | None = Field(default_factory=list[str])
|
||||
models: list[str] | None = Field(default_factory=list[str])
|
||||
endpoints: list[str] | None = Field(default_factory=list[str])
|
||||
datasources: list[str] | None = Field(default_factory=list[str])
|
||||
|
||||
class Meta(BaseModel):
|
||||
minimum_dify_version: str | None = Field(default=None)
|
||||
|
|
@ -101,11 +101,11 @@ class PluginDeclaration(BaseModel):
|
|||
tags: list[str] = Field(default_factory=list)
|
||||
repo: str | None = Field(default=None)
|
||||
verified: bool = Field(default=False)
|
||||
tool: Optional[ToolProviderEntity] = None
|
||||
model: Optional[ProviderEntity] = None
|
||||
endpoint: Optional[EndpointProviderDeclaration] = None
|
||||
agent_strategy: Optional[AgentStrategyProviderEntity] = None
|
||||
datasource: Optional[DatasourceProviderEntity] = None
|
||||
tool: ToolProviderEntity | None = None
|
||||
model: ProviderEntity | None = None
|
||||
endpoint: EndpointProviderDeclaration | None = None
|
||||
agent_strategy: AgentStrategyProviderEntity | None = None
|
||||
datasource: DatasourceProviderEntity | None = None
|
||||
meta: Meta
|
||||
|
||||
@field_validator("version")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from collections.abc import Mapping
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -27,12 +27,12 @@ class DatasourceErrorEvent(BaseDatasourceEvent):
|
|||
class DatasourceCompletedEvent(BaseDatasourceEvent):
|
||||
event: str = DatasourceStreamEvent.COMPLETED.value
|
||||
data: Mapping[str, Any] | list = Field(..., description="result")
|
||||
total: Optional[int] = Field(default=0, description="total")
|
||||
completed: Optional[int] = Field(default=0, description="completed")
|
||||
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
|
||||
total: int | None = Field(default=0, description="total")
|
||||
completed: int | None = Field(default=0, description="completed")
|
||||
time_consuming: float | None = Field(default=0.0, description="time consuming")
|
||||
|
||||
|
||||
class DatasourceProcessingEvent(BaseDatasourceEvent):
|
||||
event: str = DatasourceStreamEvent.PROCESSING.value
|
||||
total: Optional[int] = Field(..., description="total")
|
||||
completed: Optional[int] = Field(..., description="completed")
|
||||
total: int | None = Field(..., description="total")
|
||||
completed: int | None = Field(..., description="completed")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from models.dataset import Document
|
||||
|
|
@ -11,7 +9,7 @@ class NotionInfo(BaseModel):
|
|||
Notion import info.
|
||||
"""
|
||||
|
||||
credential_id: Optional[str] = None
|
||||
credential_id: str | None = None
|
||||
notion_workspace_id: str
|
||||
notion_obj_id: str
|
||||
notion_page_type: str
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import logging
|
||||
import operator
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -35,9 +35,9 @@ class NotionExtractor(BaseExtractor):
|
|||
notion_obj_id: str,
|
||||
notion_page_type: str,
|
||||
tenant_id: str,
|
||||
document_model: Optional[DocumentModel] = None,
|
||||
notion_access_token: Optional[str] = None,
|
||||
credential_id: Optional[str] = None,
|
||||
document_model: DocumentModel | None = None,
|
||||
notion_access_token: str | None = None,
|
||||
credential_id: str | None = None,
|
||||
):
|
||||
self._notion_access_token = None
|
||||
self._document_model = document_model
|
||||
|
|
@ -369,7 +369,7 @@ class NotionExtractor(BaseExtractor):
|
|||
return cast(str, data["last_edited_time"])
|
||||
|
||||
@classmethod
|
||||
def _get_access_token(cls, tenant_id: str, credential_id: Optional[str]) -> str:
|
||||
def _get_access_token(cls, tenant_id: str, credential_id: str | None) -> str:
|
||||
# get credential from tenant_id and credential_id
|
||||
if not credential_id:
|
||||
raise Exception(f"No credential id found for tenant {tenant_id}")
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import json
|
|||
import logging
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Optional, TypeVar, Union
|
||||
from typing import Any, TypeVar, Union
|
||||
|
||||
import psycopg2.errors
|
||||
from sqlalchemy import UnaryExpression, asc, desc, select
|
||||
|
|
@ -291,7 +291,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
return None
|
||||
|
||||
value_json = _deterministic_json_dump(json_encodable_value)
|
||||
assert value_json is not None, "value_json should be None here."
|
||||
assert value_json is not None, "value_json should be not None here."
|
||||
|
||||
suffix = type_.value
|
||||
upload_file = self._file_service.upload_file(
|
||||
|
|
@ -530,7 +530,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
order_config: OrderConfig | None = None,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ class SchemaRegistry:
|
|||
except (OSError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: failed to load schema {version}/{schema_name}: {e}")
|
||||
|
||||
def get_schema(self, uri: str) -> Optional[Any]:
|
||||
def get_schema(self, uri: str) -> Any | None:
|
||||
"""Retrieves a schema by URI with version support"""
|
||||
version, schema_name = self._parse_uri(uri)
|
||||
if not version or not schema_name:
|
||||
|
|
|
|||
|
|
@@ -3,7 +3,7 @@ import re
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, Optional, Union
from typing import Any, Union

from core.schemas.registry import SchemaRegistry

@@ -53,8 +53,8 @@ class QueueItem:
"""Represents an item in the BFS queue"""

current: Any
parent: Optional[Any]
key: Optional[Union[str, int]]
parent: Any | None
key: Union[str, int] | None
depth: int
ref_path: set[str]

@@ -65,7 +65,7 @@ class SchemaResolver:
_cache: dict[str, SchemaDict] = {}
_cache_lock = threading.Lock()

def __init__(self, registry: Optional[SchemaRegistry] = None, max_depth: int = 10):
def __init__(self, registry: SchemaRegistry | None = None, max_depth: int = 10):
"""
Initialize the schema resolver

@@ -202,7 +202,7 @@ class SchemaResolver:
)
)

def _get_resolved_schema(self, ref_uri: str) -> Optional[SchemaDict]:
def _get_resolved_schema(self, ref_uri: str) -> SchemaDict | None:
"""Get resolved schema from cache or registry"""
# Check cache first
with self._cache_lock:

@@ -223,7 +223,7 @@ class SchemaResolver:


def resolve_dify_schema_refs(
schema: SchemaType, registry: Optional[SchemaRegistry] = None, max_depth: int = 30
schema: SchemaType, registry: SchemaRegistry | None = None, max_depth: int = 30
) -> SchemaType:
"""
Resolve $ref references in Dify schema to actual schema content

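Aside: the migrated QueueItem still mixes both spellings (Union[str, int] | None could equally be written str | int | None). A self-contained sketch of the post-migration dataclass, reconstructed from the hunk above:

    from dataclasses import dataclass
    from typing import Any, Union

    @dataclass
    class QueueItem:
        """Represents an item in the BFS queue"""

        current: Any
        parent: Any | None
        key: Union[str, int] | None
        depth: int
        ref_path: set[str]
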
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Any, Optional
from typing import Any

from core.schemas.registry import SchemaRegistry

@@ -7,7 +7,7 @@ from core.schemas.registry import SchemaRegistry
class SchemaManager:
"""Schema manager provides high-level schema operations"""

def __init__(self, registry: Optional[SchemaRegistry] = None):
def __init__(self, registry: SchemaRegistry | None = None):
self.registry = registry or SchemaRegistry.default_registry()

def get_all_schema_definitions(self, version: str = "v1") -> list[Mapping[str, Any]]:

@@ -22,7 +22,7 @@ class SchemaManager:
"""
return self.registry.get_all_schemas_for_version(version)

def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Optional[Mapping[str, Any]]:
def get_schema_by_name(self, schema_name: str, version: str = "v1") -> Mapping[str, Any] | None:
"""
Get a specific schema by name

@@ -67,7 +67,7 @@ class VariablePool(BaseModel):
self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var)
# Add rag pipeline variables to the variable pool
if self.rag_pipeline_variables:
rag_pipeline_variables_map = defaultdict(dict)
rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
for rag_var in self.rag_pipeline_variables:
node_id = rag_var.variable.belong_to_node_id
key = rag_var.variable.variable

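The only change in this hunk is an explicit annotation on the defaultdict; behaviour is unchanged. A short sketch (illustrative keys, not from the commit) of why the annotation helps the type checker, which otherwise typically asks for one on a bare defaultdict(dict):

    from collections import defaultdict
    from typing import Any

    # Spelling out the key and value types keeps the nested-dict intent explicit for mypy.
    rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict)
    rag_pipeline_variables_map["node_1"]["query"] = "hello"  # inner dict is created on first access
    print(rag_pipeline_variables_map["node_1"])  # {'query': 'hello'}
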
@@ -8,7 +8,7 @@ and don't contain implementation details like tenant_id, app_id, etc.

from collections.abc import Mapping
from datetime import datetime
from typing import Any, Optional
from typing import Any

from pydantic import BaseModel, Field, PrivateAttr

@@ -53,9 +53,9 @@ class WorkflowNodeExecution(BaseModel):

# Execution data
# The `inputs` and `outputs` fields hold the full content
inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node
process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data
outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node

# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status

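For the pydantic models touched here the new union syntax changes nothing at runtime; a tiny sketch (illustrative class name, field subset only) confirming that the None defaults still validate as before:

    from collections.abc import Mapping
    from typing import Any

    from pydantic import BaseModel

    class NodeExecutionSnapshot(BaseModel):  # stand-in, not the real WorkflowNodeExecution
        inputs: Mapping[str, Any] | None = None
        outputs: Mapping[str, Any] | None = None

    print(NodeExecutionSnapshot().inputs)                  # None
    print(NodeExecutionSnapshot(inputs={"q": 1}).inputs)   # {'q': 1}
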
@@ -13,4 +13,4 @@ class WorkflowNodeRunFailedError(Exception):

@property
def error(self) -> str:
return self._error
return self._error

@ -1,5 +1,5 @@
|
|||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
|
@ -19,7 +19,7 @@ from core.file.enums import FileTransferMethod, FileType
|
|||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.variables.segments import ArrayAnySegment
|
||||
from core.variables.variables import ArrayAnyVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, SystemVariableKey
|
||||
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
|
||||
|
|
@ -50,7 +50,7 @@ class DatasourceNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = DatasourceNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -59,7 +59,7 @@ class DatasourceNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -87,29 +87,18 @@ class DatasourceNode(Node):
|
|||
raise DatasourceNodeError("Invalid datasource info format")
|
||||
datasource_info: dict[str, Any] = datasource_info_value
|
||||
# get datasource runtime
|
||||
try:
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
if datasource_type is None:
|
||||
raise DatasourceNodeError("Datasource type is not set")
|
||||
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
except DatasourceNodeError as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||
error=f"Failed to get datasource runtime: {str(e)}",
|
||||
error_type=type(e).__name__,
|
||||
)
|
||||
)
|
||||
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||
datasource_name=node_data.datasource_name or "",
|
||||
tenant_id=self.tenant_id,
|
||||
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||
)
|
||||
datasource_info["icon"] = datasource_runtime.get_icon_url(self.tenant_id)
|
||||
|
||||
parameters_for_log = datasource_info
|
||||
|
||||
|
|
@ -179,7 +168,7 @@ class DatasourceNode(Node):
|
|||
related_id = datasource_info.get("related_id")
|
||||
if not related_id:
|
||||
raise DatasourceNodeError("File is not exist")
|
||||
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == related_id).first()
|
||||
if not upload_file:
|
||||
raise ValueError("Invalid upload file Info")
|
||||
|
||||
|
|
@ -282,27 +271,6 @@ class DatasourceNode(Node):
|
|||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _append_variables_recursively(
|
||||
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
|
||||
):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
variable_pool.add([node_id] + [".".join(variable_key_list)], variable_value)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
|
|
@ -423,13 +391,6 @@ class DatasourceNode(Node):
|
|||
)
|
||||
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||
if self.node_type == NodeType.AGENT:
|
||||
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||
agent_execution_metadata = {
|
||||
key: value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
json.append(message.message.json_object)
|
||||
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||
|
|
|
|||
|
|
@@ -1,4 +1,4 @@
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Union

from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo

@@ -10,7 +10,7 @@ class DatasourceEntity(BaseModel):
plugin_id: str
provider_name: str # redundancy
provider_type: str
datasource_name: Optional[str] = "local_file"
datasource_name: str | None = "local_file"
datasource_configurations: dict[str, Any] | None = None
plugin_unique_identifier: str | None = None # redundancy

@@ -19,7 +19,7 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity):
class DatasourceInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Optional[Literal["mixed", "variable", "constant"]] = None
type: Literal["mixed", "variable", "constant"] | None = None

@field_validator("type", mode="before")
@classmethod

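A hedged sketch of the DatasourceInput shape after this hunk (trimmed to the fields shown; the validator body is an assumption, since the commit does not show it):

    from typing import Any, Literal, Union

    from pydantic import BaseModel, field_validator

    class DatasourceInputSketch(BaseModel):  # stand-in for the real DatasourceInput
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"] | None = None

        @field_validator("type", mode="before")
        @classmethod
        def _normalize_type(cls, v: Any) -> Any:
            # assumption: empty strings coming from the client are treated as "not set"
            return None if v == "" else v

    print(DatasourceInputSketch(value="x", type="").type)  # None
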
@@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union

from pydantic import BaseModel

@@ -65,12 +65,12 @@ class RetrievalSetting(BaseModel):

search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
top_k: int
score_threshold: Optional[float] = 0.5
score_threshold: float | None = 0.5
score_threshold_enabled: bool = False
reranking_mode: str = "reranking_model"
reranking_enable: bool = True
reranking_model: Optional[RerankingModelConfig] = None
weights: Optional[WeightedScoreConfig] = None
reranking_model: RerankingModelConfig | None = None
weights: WeightedScoreConfig | None = None


class IndexMethod(BaseModel):

@@ -107,10 +107,10 @@ class OnlineDocumentInfo(BaseModel):
"""

provider: str
workspace_id: Optional[str] = None
workspace_id: str | None = None
page_id: str
page_type: str
icon: Optional[OnlineDocumentIcon] = None
icon: OnlineDocumentIcon | None = None


class WebsiteInfo(BaseModel):

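For reference, a minimal usage sketch of the RetrievalSetting defaults shown above (illustrative class name, optional sub-models omitted):

    from typing import Literal

    from pydantic import BaseModel

    class RetrievalSettingSketch(BaseModel):  # defaults copied from the hunk above
        search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
        top_k: int
        score_threshold: float | None = 0.5
        score_threshold_enabled: bool = False
        reranking_mode: str = "reranking_model"
        reranking_enable: bool = True

    setting = RetrievalSettingSketch(search_method="hybrid_search", top_k=5)
    print(setting.score_threshold)  # 0.5
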
@ -2,7 +2,7 @@ import datetime
|
|||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ class KnowledgeIndexNode(Node):
|
|||
def init_node_data(self, data: Mapping[str, Any]) -> None:
|
||||
self._node_data = KnowledgeIndexNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self) -> Optional[ErrorStrategy]:
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
|
|
@ -52,7 +52,7 @@ class KnowledgeIndexNode(Node):
|
|||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self) -> Optional[str]:
|
||||
def _get_description(self) -> str | None:
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
|
|
@ -160,7 +160,7 @@ class KnowledgeIndexNode(Node):
|
|||
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||
document.word_count = (
|
||||
db.session.query(func.sum(DocumentSegment.word_count))
|
||||
.filter(
|
||||
.where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
)
|
||||
|
|
@ -168,7 +168,7 @@ class KnowledgeIndexNode(Node):
|
|||
)
|
||||
db.session.add(document)
|
||||
# update document segment status
|
||||
db.session.query(DocumentSegment).filter(
|
||||
db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.document_id == document.id,
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
).update(
|
||||
|
|
|
|||
|
|
@@ -92,7 +92,9 @@ class LoopNode(Node):
if self._node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value),
"variable": lambda var: self.graph_runtime_state.variable_pool.get(var.value)
if isinstance(var.value, list)
else None,
}
for loop_variable in self._node_data.loop_variables:
if loop_variable.value_type not in value_processor:

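The added isinstance guard keeps non-list values away from the variable-pool lookup, which expects a selector list. A hedged sketch of the pattern with stand-in names (pool_get and the selector data are illustrative, not from the commit):

    from collections.abc import Callable
    from typing import Any

    # Stand-in for the variable pool lookup: it expects a selector list such as ["node_id", "var"].
    def pool_get(selector: list[str]) -> Any | None:
        return {"node_1.item": 42}.get(".".join(selector))

    # Non-list values now short-circuit to None instead of reaching the lookup.
    resolve: Callable[[Any], Any | None] = lambda value: pool_get(value) if isinstance(value, list) else None

    print(resolve(["node_1", "item"]))  # 42
    print(resolve("not-a-selector"))    # None
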
@@ -326,7 +326,7 @@ def _build_from_datasource_file(
) -> File:
datasource_file = (
db.session.query(UploadFile)
.filter(
.where(
UploadFile.id == mapping.get("datasource_file_id"),
UploadFile.tenant_id == tenant_id,
)

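The other recurring change in this commit is .filter(...) to .where(...) on legacy Query objects. Since SQLAlchemy 1.4, Query.where() is a synonym for Query.filter(), so these hunks are behaviour-preserving and simply match the 2.0-style select() API. A self-contained sketch (in-memory SQLite, illustrative model, assumes SQLAlchemy 2.x):

    from sqlalchemy import String, create_engine, select
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

    class Base(DeclarativeBase):
        pass

    class UploadFile(Base):  # illustrative model, not the real one
        __tablename__ = "upload_files"
        id: Mapped[int] = mapped_column(primary_key=True)
        tenant_id: Mapped[str] = mapped_column(String(36))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(UploadFile(id=1, tenant_id="t1"))
        session.commit()
        # the legacy spelling and the one this branch migrates to are equivalent:
        a = session.query(UploadFile).filter(UploadFile.tenant_id == "t1").first()
        b = session.query(UploadFile).where(UploadFile.tenant_id == "t1").first()
        # 2.0-style equivalent
        c = session.scalars(select(UploadFile).where(UploadFile.tenant_id == "t1")).first()
        assert a is b is c
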
@ -10,7 +10,7 @@ import re
|
|||
import time
|
||||
from datetime import datetime
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import DateTime, String, func, select
|
||||
|
|
@ -76,13 +76,13 @@ class Dataset(Base):
|
|||
|
||||
@property
|
||||
def total_documents(self):
|
||||
return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
|
||||
return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
|
||||
|
||||
@property
|
||||
def total_available_documents(self):
|
||||
return (
|
||||
db.session.query(func.count(Document.id))
|
||||
.filter(
|
||||
.where(
|
||||
Document.dataset_id == self.id,
|
||||
Document.indexing_status == "completed",
|
||||
Document.enabled == True,
|
||||
|
|
@ -173,10 +173,10 @@ class Dataset(Base):
|
|||
)
|
||||
|
||||
@property
|
||||
def doc_form(self) -> Optional[str]:
|
||||
def doc_form(self) -> str | None:
|
||||
if self.chunk_structure:
|
||||
return self.chunk_structure
|
||||
document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
|
||||
document = db.session.query(Document).where(Document.dataset_id == self.id).first()
|
||||
if document:
|
||||
return document.doc_form
|
||||
return None
|
||||
|
|
@ -234,7 +234,7 @@ class Dataset(Base):
|
|||
@property
|
||||
def is_published(self):
|
||||
if self.pipeline_id:
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == self.pipeline_id).first()
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first()
|
||||
if pipeline:
|
||||
return pipeline.is_published
|
||||
return False
|
||||
|
|
@ -1244,7 +1244,7 @@ class PipelineBuiltInTemplate(Base): # type: ignore[name-defined]
|
|||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).filter(Account.id == self.created_by).first()
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
|
|
@ -1274,7 +1274,7 @@ class PipelineCustomizedTemplate(Base): # type: ignore[name-defined]
|
|||
|
||||
@property
|
||||
def created_user_name(self):
|
||||
account = db.session.query(Account).filter(Account.id == self.created_by).first()
|
||||
account = db.session.query(Account).where(Account.id == self.created_by).first()
|
||||
if account:
|
||||
return account.name
|
||||
return ""
|
||||
|
|
@ -1297,7 +1297,7 @@ class Pipeline(Base): # type: ignore[name-defined]
|
|||
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
|
||||
def retrieve_dataset(self, session: Session):
|
||||
return session.query(Dataset).filter(Dataset.pipeline_id == self.id).first()
|
||||
return session.query(Dataset).where(Dataset.pipeline_id == self.id).first()
|
||||
|
||||
|
||||
class DocumentPipelineExecutionLog(Base):
|
||||
|
|
|
|||
|
|
@ -1546,7 +1546,7 @@ class WorkflowDraftVariableFile(Base):
|
|||
comment="Size of the original variable content in bytes",
|
||||
)
|
||||
|
||||
length: Mapped[Optional[int]] = mapped_column(
|
||||
length: Mapped[int | None] = mapped_column(
|
||||
sa.Integer,
|
||||
nullable=True,
|
||||
comment=(
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import time
|
|||
import uuid
|
||||
from collections import Counter
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Literal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists, func, select
|
||||
|
|
@ -315,8 +315,8 @@ class DatasetService:
|
|||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def get_dataset(dataset_id) -> Optional[Dataset]:
|
||||
dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
def get_dataset(dataset_id) -> Dataset | None:
|
||||
dataset: Dataset | None = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -419,7 +419,7 @@ class DatasetService:
|
|||
def _has_dataset_same_name(tenant_id: str, dataset_id: str, name: str):
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.filter(
|
||||
.where(
|
||||
Dataset.id != dataset_id,
|
||||
Dataset.name == name,
|
||||
Dataset.tenant_id == tenant_id,
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ class DatasourceProviderService:
|
|||
provider_id=f"{plugin_id}/{provider}",
|
||||
credential_type=CredentialType.of(datasource_provider.auth_type),
|
||||
)
|
||||
encrypted_credentials = raw_credentials.copy()
|
||||
encrypted_credentials = dict(raw_credentials)
|
||||
for key, value in encrypted_credentials.items():
|
||||
if key in provider_credential_secret_variables:
|
||||
encrypted_credentials[key] = encrypter.encrypt_token(tenant_id, value)
|
||||
|
|
@ -690,7 +690,7 @@ class DatasourceProviderService:
|
|||
# Get all provider configurations of the current workspace
|
||||
datasource_providers: list[DatasourceProvider] = (
|
||||
db.session.query(DatasourceProvider)
|
||||
.filter(
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
|
|
@ -862,7 +862,7 @@ class DatasourceProviderService:
|
|||
# Get all provider configurations of the current workspace
|
||||
datasource_providers: list[DatasourceProvider] = (
|
||||
db.session.query(DatasourceProvider)
|
||||
.filter(
|
||||
.where(
|
||||
DatasourceProvider.tenant_id == tenant_id,
|
||||
DatasourceProvider.provider == provider,
|
||||
DatasourceProvider.plugin_id == plugin_id,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from typing import Literal, Optional
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class IconInfo(BaseModel):
|
||||
icon: str
|
||||
icon_background: Optional[str] = None
|
||||
icon_type: Optional[str] = None
|
||||
icon_url: Optional[str] = None
|
||||
icon_background: str | None = None
|
||||
icon_type: str | None = None
|
||||
icon_url: str | None = None
|
||||
|
||||
|
||||
class PipelineTemplateInfoEntity(BaseModel):
|
||||
|
|
@ -21,8 +21,8 @@ class RagPipelineDatasetCreateEntity(BaseModel):
|
|||
description: str
|
||||
icon_info: IconInfo
|
||||
permission: str
|
||||
partial_member_list: Optional[list[str]] = None
|
||||
yaml_content: Optional[str] = None
|
||||
partial_member_list: list[str] | None = None
|
||||
yaml_content: str | None = None
|
||||
|
||||
|
||||
class RerankingModelConfig(BaseModel):
|
||||
|
|
@ -30,8 +30,8 @@ class RerankingModelConfig(BaseModel):
|
|||
Reranking Model Config.
|
||||
"""
|
||||
|
||||
reranking_provider_name: Optional[str] = ""
|
||||
reranking_model_name: Optional[str] = ""
|
||||
reranking_provider_name: str | None = ""
|
||||
reranking_model_name: str | None = ""
|
||||
|
||||
|
||||
class VectorSetting(BaseModel):
|
||||
|
|
@ -57,8 +57,8 @@ class WeightedScoreConfig(BaseModel):
|
|||
Weighted score Config.
|
||||
"""
|
||||
|
||||
vector_setting: Optional[VectorSetting]
|
||||
keyword_setting: Optional[KeywordSetting]
|
||||
vector_setting: VectorSetting | None
|
||||
keyword_setting: KeywordSetting | None
|
||||
|
||||
|
||||
class EmbeddingSetting(BaseModel):
|
||||
|
|
@ -85,12 +85,12 @@ class RetrievalSetting(BaseModel):
|
|||
|
||||
search_method: Literal["semantic_search", "fulltext_search", "keyword_search", "hybrid_search"]
|
||||
top_k: int
|
||||
score_threshold: Optional[float] = 0.5
|
||||
score_threshold: float | None = 0.5
|
||||
score_threshold_enabled: bool = False
|
||||
reranking_mode: Optional[str] = "reranking_model"
|
||||
reranking_enable: Optional[bool] = True
|
||||
reranking_model: Optional[RerankingModelConfig] = None
|
||||
weights: Optional[WeightedScoreConfig] = None
|
||||
reranking_mode: str | None = "reranking_model"
|
||||
reranking_enable: bool | None = True
|
||||
reranking_model: RerankingModelConfig | None = None
|
||||
weights: WeightedScoreConfig | None = None
|
||||
|
||||
|
||||
class IndexMethod(BaseModel):
|
||||
|
|
@ -112,7 +112,7 @@ class KnowledgeConfiguration(BaseModel):
|
|||
indexing_technique: Literal["high_quality", "economy"]
|
||||
embedding_model_provider: str = ""
|
||||
embedding_model: str = ""
|
||||
keyword_number: Optional[int] = 10
|
||||
keyword_number: int | None = 10
|
||||
retrieval_model: RetrievalSetting
|
||||
|
||||
@field_validator("embedding_model_provider", mode="before")
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from collections.abc import Mapping
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -7,9 +7,9 @@ from pydantic import BaseModel
|
|||
class DatasourceNodeRunApiEntity(BaseModel):
|
||||
pipeline_id: str
|
||||
node_id: str
|
||||
inputs: Mapping[str, Any]
|
||||
inputs: dict[str, Any]
|
||||
datasource_type: str
|
||||
credential_id: Optional[str] = None
|
||||
credential_id: str | None = None
|
||||
is_published: bool
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ class PipelineGenerateService:
|
|||
Update document status to waiting
|
||||
:param document_id: document id
|
||||
"""
|
||||
document = db.session.query(Document).filter(Document.id == document_id).first()
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
if document:
|
||||
document.indexing_status = "waiting"
|
||||
db.session.add(document)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from flask import current_app
|
||||
|
||||
|
|
@ -14,7 +13,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
Retrieval pipeline template from built-in, the location is constants/pipeline_templates.json
|
||||
"""
|
||||
|
||||
builtin_data: Optional[dict] = None
|
||||
builtin_data: dict | None = None
|
||||
|
||||
def get_type(self) -> str:
|
||||
return PipelineTemplateType.BUILTIN
|
||||
|
|
@ -54,7 +53,7 @@ class BuiltInPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
return builtin_data.get("pipeline_templates", {}).get(language, {})
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_builtin(cls, template_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch pipeline template detail from builtin.
|
||||
:param template_id: Template ID
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from flask_login import current_user
|
||||
|
||||
|
|
@ -37,7 +35,7 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
"""
|
||||
pipeline_customized_templates = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
.filter(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
|
||||
.where(PipelineCustomizedTemplate.tenant_id == tenant_id, PipelineCustomizedTemplate.language == language)
|
||||
.order_by(PipelineCustomizedTemplate.position.asc(), PipelineCustomizedTemplate.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
|
@ -56,14 +54,14 @@ class CustomizedPipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
return {"pipeline_templates": recommended_pipelines_results}
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch pipeline template detail from db.
|
||||
:param template_id: Template ID
|
||||
:return:
|
||||
"""
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||
db.session.query(PipelineCustomizedTemplate).where(PipelineCustomizedTemplate.id == template_id).first()
|
||||
)
|
||||
if not pipeline_template:
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
|
@ -33,7 +31,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
"""
|
||||
|
||||
pipeline_built_in_templates: list[PipelineBuiltInTemplate] = (
|
||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.language == language).all()
|
||||
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.language == language).all()
|
||||
)
|
||||
|
||||
recommended_pipelines_results = []
|
||||
|
|
@ -53,7 +51,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
return {"pipeline_templates": recommended_pipelines_results}
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_db(cls, template_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch pipeline template detail from db.
|
||||
:param pipeline_id: Pipeline ID
|
||||
|
|
@ -61,7 +59,7 @@ class DatabasePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
"""
|
||||
# is in public recommended list
|
||||
pipeline_template = (
|
||||
db.session.query(PipelineBuiltInTemplate).filter(PipelineBuiltInTemplate.id == template_id).first()
|
||||
db.session.query(PipelineBuiltInTemplate).where(PipelineBuiltInTemplate.id == template_id).first()
|
||||
)
|
||||
|
||||
if not pipeline_template:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class PipelineTemplateRetrievalBase(ABC):
|
||||
|
|
@ -10,7 +9,7 @@ class PipelineTemplateRetrievalBase(ABC):
|
|||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_pipeline_template_detail(self, template_id: str) -> Optional[dict]:
|
||||
def get_pipeline_template_detail(self, template_id: str) -> dict | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -36,7 +35,7 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase):
|
|||
return PipelineTemplateType.REMOTE
|
||||
|
||||
@classmethod
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> Optional[dict]:
|
||||
def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None:
|
||||
"""
|
||||
Fetch pipeline template detail from dify official.
|
||||
:param template_id: Pipeline ID
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import threading
|
|||
import time
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from typing import Any, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from flask_login import current_user
|
||||
|
|
@ -112,7 +112,7 @@ class RagPipelineService:
|
|||
return result
|
||||
|
||||
@classmethod
|
||||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> Optional[dict]:
|
||||
def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None:
|
||||
"""
|
||||
Get pipeline template detail.
|
||||
:param template_id: template id
|
||||
|
|
@ -121,12 +121,12 @@ class RagPipelineService:
|
|||
if type == "built-in":
|
||||
mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
built_in_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return built_in_result
|
||||
else:
|
||||
mode = "customized"
|
||||
retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)()
|
||||
customized_result: Optional[dict] = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
customized_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id)
|
||||
return customized_result
|
||||
|
||||
@classmethod
|
||||
|
|
@ -138,7 +138,7 @@ class RagPipelineService:
|
|||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
.filter(
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
|
|
@ -151,7 +151,7 @@ class RagPipelineService:
|
|||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
.filter(
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
PipelineCustomizedTemplate.id != template_id,
|
||||
|
|
@ -174,7 +174,7 @@ class RagPipelineService:
|
|||
"""
|
||||
customized_template: PipelineCustomizedTemplate | None = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
.filter(
|
||||
.where(
|
||||
PipelineCustomizedTemplate.id == template_id,
|
||||
PipelineCustomizedTemplate.tenant_id == current_user.current_tenant_id,
|
||||
)
|
||||
|
|
@ -185,14 +185,14 @@ class RagPipelineService:
|
|||
db.session.delete(customized_template)
|
||||
db.session.commit()
|
||||
|
||||
def get_draft_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
|
||||
def get_draft_workflow(self, pipeline: Pipeline) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow
|
||||
"""
|
||||
# fetch draft workflow by rag pipeline
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
|
|
@ -203,7 +203,7 @@ class RagPipelineService:
|
|||
# return draft workflow
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, pipeline: Pipeline) -> Optional[Workflow]:
|
||||
def get_published_workflow(self, pipeline: Pipeline) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow
|
||||
"""
|
||||
|
|
@ -214,7 +214,7 @@ class RagPipelineService:
|
|||
# fetch published workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.id == pipeline.workflow_id,
|
||||
|
|
@ -267,7 +267,7 @@ class RagPipelineService:
|
|||
*,
|
||||
pipeline: Pipeline,
|
||||
graph: dict,
|
||||
unique_hash: Optional[str],
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
|
|
@ -378,18 +378,16 @@ class RagPipelineService:
|
|||
Get default block configs
|
||||
"""
|
||||
# return default block config
|
||||
default_block_configs = []
|
||||
default_block_configs: list[dict[str, Any]] = []
|
||||
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
||||
node_class = node_class_mapping[LATEST_VERSION]
|
||||
default_config = node_class.get_default_config()
|
||||
if default_config:
|
||||
default_block_configs.append(default_config)
|
||||
default_block_configs.append(dict(default_config))
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(
|
||||
self, node_type: str, filters: Optional[dict] = None
|
||||
) -> Optional[Mapping[str, object]]:
|
||||
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
|
||||
"""
|
||||
Get default config of node.
|
||||
:param node_type: node type
|
||||
|
|
@ -495,7 +493,7 @@ class RagPipelineService:
|
|||
account: Account,
|
||||
datasource_type: str,
|
||||
is_published: bool,
|
||||
credential_id: Optional[str] = None,
|
||||
credential_id: str | None = None,
|
||||
) -> Generator[Mapping[str, Any], None, None]:
|
||||
"""
|
||||
Run published workflow datasource
|
||||
|
|
@ -580,10 +578,10 @@ class RagPipelineService:
|
|||
)
|
||||
yield start_event.model_dump()
|
||||
try:
|
||||
for message in online_document_result:
|
||||
for online_document_message in online_document_result:
|
||||
end_time = time.time()
|
||||
online_document_event = DatasourceCompletedEvent(
|
||||
data=message.result, time_consuming=round(end_time - start_time, 2)
|
||||
data=online_document_message.result, time_consuming=round(end_time - start_time, 2)
|
||||
)
|
||||
yield online_document_event.model_dump()
|
||||
except Exception as e:
|
||||
|
|
@ -609,10 +607,10 @@ class RagPipelineService:
|
|||
completed=0,
|
||||
)
|
||||
yield start_event.model_dump()
|
||||
for message in online_drive_result:
|
||||
for online_drive_message in online_drive_result:
|
||||
end_time = time.time()
|
||||
online_drive_event = DatasourceCompletedEvent(
|
||||
data=message.result,
|
||||
data=online_drive_message.result,
|
||||
time_consuming=round(end_time - start_time, 2),
|
||||
total=None,
|
||||
completed=None,
|
||||
|
|
@ -629,19 +627,20 @@ class RagPipelineService:
|
|||
)
|
||||
start_time = time.time()
|
||||
try:
|
||||
for message in website_crawl_result:
|
||||
for website_crawl_message in website_crawl_result:
|
||||
end_time = time.time()
|
||||
if message.result.status == "completed":
|
||||
crawl_event: DatasourceCompletedEvent | DatasourceProcessingEvent
|
||||
if website_crawl_message.result.status == "completed":
|
||||
crawl_event = DatasourceCompletedEvent(
|
||||
data=message.result.web_info_list or [],
|
||||
total=message.result.total,
|
||||
completed=message.result.completed,
|
||||
data=website_crawl_message.result.web_info_list or [],
|
||||
total=website_crawl_message.result.total,
|
||||
completed=website_crawl_message.result.completed,
|
||||
time_consuming=round(end_time - start_time, 2),
|
||||
)
|
||||
else:
|
||||
crawl_event = DatasourceProcessingEvent(
|
||||
total=message.result.total,
|
||||
completed=message.result.completed,
|
||||
total=website_crawl_message.result.total,
|
||||
completed=website_crawl_message.result.completed,
|
||||
)
|
||||
yield crawl_event.model_dump()
|
||||
except Exception as e:
|
||||
|
|
@ -661,7 +660,7 @@ class RagPipelineService:
|
|||
account: Account,
|
||||
datasource_type: str,
|
||||
is_published: bool,
|
||||
credential_id: Optional[str] = None,
|
||||
credential_id: str | None = None,
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Run published workflow datasource
|
||||
|
|
@ -723,12 +722,12 @@ class RagPipelineService:
|
|||
)
|
||||
try:
|
||||
variables: dict[str, Any] = {}
|
||||
for message in online_document_result:
|
||||
if message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
for online_document_message in online_document_result:
|
||||
if online_document_message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||
assert isinstance(online_document_message.message, DatasourceMessage.VariableMessage)
|
||||
variable_name = online_document_message.message.variable_name
|
||||
variable_value = online_document_message.message.variable_value
|
||||
if online_document_message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||
if variable_name not in variables:
|
||||
|
|
@ -793,8 +792,9 @@ class RagPipelineService:
|
|||
for event in generator:
|
||||
if isinstance(event, (NodeRunSucceededEvent, NodeRunFailedEvent)):
|
||||
node_run_result = event.node_run_result
|
||||
# sign output files
|
||||
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
|
||||
if node_run_result:
|
||||
# sign output files
|
||||
node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) or {}
|
||||
break
|
||||
|
||||
if not node_run_result:
|
||||
|
|
@ -876,7 +876,7 @@ class RagPipelineService:
|
|||
if invoke_from.value == InvokeFrom.PUBLISHED.value:
|
||||
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||
if document_id:
|
||||
document = db.session.query(Document).filter(Document.id == document_id.value).first()
|
||||
document = db.session.query(Document).where(Document.id == document_id.value).first()
|
||||
if document:
|
||||
document.indexing_status = "error"
|
||||
document.error = error
|
||||
|
|
@ -887,7 +887,7 @@ class RagPipelineService:
|
|||
|
||||
def update_workflow(
|
||||
self, *, session: Session, workflow_id: str, tenant_id: str, account_id: str, data: dict
|
||||
) -> Optional[Workflow]:
|
||||
) -> Workflow | None:
|
||||
"""
|
||||
Update workflow attributes
|
||||
|
||||
|
|
@ -1016,7 +1016,7 @@ class RagPipelineService:
|
|||
"""
|
||||
limit = int(args.get("limit", 20))
|
||||
|
||||
base_query = db.session.query(WorkflowRun).filter(
|
||||
base_query = db.session.query(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == pipeline.tenant_id,
|
||||
WorkflowRun.app_id == pipeline.id,
|
||||
or_(
|
||||
|
|
@ -1026,7 +1026,7 @@ class RagPipelineService:
|
|||
)
|
||||
|
||||
if args.get("last_id"):
|
||||
last_workflow_run = base_query.filter(
|
||||
last_workflow_run = base_query.where(
|
||||
WorkflowRun.id == args.get("last_id"),
|
||||
).first()
|
||||
|
||||
|
|
@ -1034,7 +1034,7 @@ class RagPipelineService:
|
|||
raise ValueError("Last workflow run not exists")
|
||||
|
||||
workflow_runs = (
|
||||
base_query.filter(
|
||||
base_query.where(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
|
||||
)
|
||||
.order_by(WorkflowRun.created_at.desc())
|
||||
|
|
@ -1047,7 +1047,7 @@ class RagPipelineService:
|
|||
has_more = False
|
||||
if len(workflow_runs) == limit:
|
||||
current_page_first_workflow_run = workflow_runs[-1]
|
||||
rest_count = base_query.filter(
|
||||
rest_count = base_query.where(
|
||||
WorkflowRun.created_at < current_page_first_workflow_run.created_at,
|
||||
WorkflowRun.id != current_page_first_workflow_run.id,
|
||||
).count()
|
||||
|
|
@ -1057,7 +1057,7 @@ class RagPipelineService:
|
|||
|
||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
||||
|
||||
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> Optional[WorkflowRun]:
|
||||
def get_rag_pipeline_workflow_run(self, pipeline: Pipeline, run_id: str) -> WorkflowRun | None:
|
||||
"""
|
||||
Get workflow run detail
|
||||
|
||||
|
|
@ -1066,7 +1066,7 @@ class RagPipelineService:
|
|||
"""
|
||||
workflow_run = (
|
||||
db.session.query(WorkflowRun)
|
||||
.filter(
|
||||
.where(
|
||||
WorkflowRun.tenant_id == pipeline.tenant_id,
|
||||
WorkflowRun.app_id == pipeline.id,
|
||||
WorkflowRun.id == run_id,
|
||||
|
|
@ -1113,12 +1113,12 @@ class RagPipelineService:
|
|||
"""
|
||||
Publish customized pipeline template
|
||||
"""
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
if not pipeline.workflow_id:
|
||||
raise ValueError("Pipeline workflow not found")
|
||||
workflow = db.session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
|
||||
workflow = db.session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
with Session(db.engine) as session:
|
||||
|
|
@ -1131,7 +1131,7 @@ class RagPipelineService:
|
|||
if template_name:
|
||||
template = (
|
||||
db.session.query(PipelineCustomizedTemplate)
|
||||
.filter(
|
||||
.where(
|
||||
PipelineCustomizedTemplate.name == template_name,
|
||||
PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id,
|
||||
)
|
||||
|
|
@ -1142,7 +1142,7 @@ class RagPipelineService:
|
|||
|
||||
max_position = (
|
||||
db.session.query(func.max(PipelineCustomizedTemplate.position))
|
||||
.filter(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
|
||||
.where(PipelineCustomizedTemplate.tenant_id == pipeline.tenant_id)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
|
|
@ -1169,7 +1169,7 @@ class RagPipelineService:
|
|||
def is_workflow_exist(self, pipeline: Pipeline) -> bool:
|
||||
return (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
|
|
@ -1278,7 +1278,7 @@ class RagPipelineService:
|
|||
# Query active recommended plugins
|
||||
pipeline_recommended_plugins = (
|
||||
db.session.query(PipelineRecommendedPlugin)
|
||||
.filter(PipelineRecommendedPlugin.active == True)
|
||||
.where(PipelineRecommendedPlugin.active == True)
|
||||
.order_by(PipelineRecommendedPlugin.position.asc())
|
||||
.all()
|
||||
)
|
||||
|
|
@ -1329,12 +1329,12 @@ class RagPipelineService:
|
|||
"""
|
||||
document_pipeline_excution_log = (
|
||||
db.session.query(DocumentPipelineExecutionLog)
|
||||
.filter(DocumentPipelineExecutionLog.document_id == document.id)
|
||||
.where(DocumentPipelineExecutionLog.document_id == document.id)
|
||||
.first()
|
||||
)
|
||||
if not document_pipeline_excution_log:
|
||||
raise ValueError("Document pipeline execution log not found")
|
||||
pipeline = db.session.query(Pipeline).filter(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
|
||||
pipeline = db.session.query(Pipeline).where(Pipeline.id == document_pipeline_excution_log.pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
# convert to app config
|
||||
|
|
@ -1358,3 +1358,99 @@ class RagPipelineService:
|
|||
workflow_thread_pool_id=None,
|
||||
is_retry=True,
|
||||
)
|
||||
|
||||
def get_datasource_plugins(self, tenant_id: str, dataset_id: str, is_published: bool) -> list[dict]:
|
||||
"""
|
||||
Get datasource plugins
|
||||
"""
|
||||
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
|
||||
workflow: Workflow | None = None
|
||||
if is_published:
|
||||
workflow = self.get_published_workflow(pipeline=pipeline)
|
||||
else:
|
||||
workflow = self.get_draft_workflow(pipeline=pipeline)
|
||||
if not pipeline or not workflow:
|
||||
raise ValueError("Pipeline or workflow not found")
|
||||
|
||||
datasource_nodes = workflow.graph_dict.get("nodes", [])
|
||||
datasource_plugins = []
|
||||
for datasource_node in datasource_nodes:
|
||||
if datasource_node.get("type") == "datasource":
|
||||
datasource_node_data = datasource_node.get("data", {})
|
||||
if not datasource_node_data:
|
||||
continue
|
||||
|
||||
variables = workflow.rag_pipeline_variables
|
||||
if variables:
|
||||
variables_map = {item["variable"]: item for item in variables}
|
||||
else:
|
||||
variables_map = {}
|
||||
|
||||
datasource_parameters = datasource_node_data.get("datasource_parameters", {})
|
||||
user_input_variables_keys = []
|
||||
user_input_variables = []
|
||||
|
||||
for _, value in datasource_parameters.items():
|
||||
if value.get("value") and isinstance(value.get("value"), str):
|
||||
pattern = r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z0-9_][a-zA-Z0-9_]{0,29}){1,10})#\}\}"
|
||||
match = re.match(pattern, value["value"])
|
||||
if match:
|
||||
full_path = match.group(1)
|
||||
last_part = full_path.split(".")[-1]
|
||||
user_input_variables_keys.append(last_part)
|
||||
elif value.get("value") and isinstance(value.get("value"), list):
|
||||
last_part = value.get("value")[-1]
|
||||
user_input_variables_keys.append(last_part)
|
||||
for key, value in variables_map.items():
|
||||
if key in user_input_variables_keys:
|
||||
user_input_variables.append(value)
|
||||
|
||||
# get credentials
|
||||
datasource_provider_service: DatasourceProviderService = DatasourceProviderService()
|
||||
credentials: list[dict[Any, Any]] = datasource_provider_service.list_datasource_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=datasource_node_data.get("provider_name"),
|
||||
plugin_id=datasource_node_data.get("plugin_id"),
|
||||
)
|
||||
credential_info_list: list[Any] = []
|
||||
for credential in credentials:
|
||||
credential_info_list.append(
|
||||
{
|
||||
"id": credential.get("id"),
|
||||
"name": credential.get("name"),
|
||||
"type": credential.get("type"),
|
||||
"is_default": credential.get("is_default"),
|
||||
}
|
||||
)
|
||||
|
||||
datasource_plugins.append(
|
||||
{
|
||||
"node_id": datasource_node.get("id"),
|
||||
"plugin_id": datasource_node_data.get("plugin_id"),
|
||||
"provider_name": datasource_node_data.get("provider_name"),
|
||||
"datasource_type": datasource_node_data.get("provider_type"),
|
||||
"title": datasource_node_data.get("title"),
|
||||
"user_input_variables": user_input_variables,
|
||||
"credentials": credential_info_list,
|
||||
}
|
||||
)
|
||||
|
||||
return datasource_plugins
|
||||
|
||||
def get_pipeline(self, tenant_id: str, dataset_id: str) -> Pipeline:
|
||||
"""
|
||||
Get pipeline
|
||||
"""
|
||||
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
|
||||
if not pipeline:
|
||||
raise ValueError("Pipeline not found")
|
||||
return pipeline
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import uuid
|
|||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from typing import Optional, cast
|
||||
from typing import cast
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
|
|
@ -66,11 +66,11 @@ class ImportStatus(StrEnum):
|
|||
class RagPipelineImportInfo(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
pipeline_id: Optional[str] = None
|
||||
pipeline_id: str | None = None
|
||||
current_dsl_version: str = CURRENT_DSL_VERSION
|
||||
imported_dsl_version: str = ""
|
||||
error: str = ""
|
||||
dataset_id: Optional[str] = None
|
||||
dataset_id: str | None = None
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
|
|
@ -121,12 +121,12 @@ class RagPipelineDslService:
|
|||
*,
|
||||
account: Account,
|
||||
import_mode: str,
|
||||
yaml_content: Optional[str] = None,
|
||||
yaml_url: Optional[str] = None,
|
||||
pipeline_id: Optional[str] = None,
|
||||
dataset: Optional[Dataset] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
icon_info: Optional[IconInfo] = None,
|
||||
yaml_content: str | None = None,
|
||||
yaml_url: str | None = None,
|
||||
pipeline_id: str | None = None,
|
||||
dataset: Dataset | None = None,
|
||||
dataset_name: str | None = None,
|
||||
icon_info: IconInfo | None = None,
|
||||
) -> RagPipelineImportInfo:
|
||||
"""Import an app from YAML content or URL."""
|
||||
import_id = str(uuid.uuid4())
|
||||
|
|
@ -318,7 +318,7 @@ class RagPipelineDslService:
|
|||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = (
|
||||
self._session.query(DatasetCollectionBinding)
|
||||
.filter(
|
||||
.where(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
|
|
@ -452,7 +452,7 @@ class RagPipelineDslService:
|
|||
if knowledge_configuration.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = (
|
||||
self._session.query(DatasetCollectionBinding)
|
||||
.filter(
|
||||
.where(
|
||||
DatasetCollectionBinding.provider_name
|
||||
== knowledge_configuration.embedding_model_provider,
|
||||
DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model,
|
||||
|
|
@ -530,10 +530,10 @@ class RagPipelineDslService:
|
|||
def _create_or_update_pipeline(
|
||||
self,
|
||||
*,
|
||||
pipeline: Optional[Pipeline],
|
||||
pipeline: Pipeline | None,
|
||||
data: dict,
|
||||
account: Account,
|
||||
dependencies: Optional[list[PluginDependency]] = None,
|
||||
dependencies: list[PluginDependency] | None = None,
|
||||
) -> Pipeline:
|
||||
"""Create a new app or update an existing one."""
|
||||
if not account.current_tenant_id:
|
||||
|
|
@ -599,7 +599,7 @@ class RagPipelineDslService:
|
|||
)
|
||||
workflow = (
|
||||
self._session.query(Workflow)
|
||||
.filter(
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
|
|
@ -673,7 +673,7 @@ class RagPipelineDslService:
|
|||
|
||||
workflow = (
|
||||
self._session.query(Workflow)
|
||||
.filter(
|
||||
.where(
|
||||
Workflow.tenant_id == pipeline.tenant_id,
|
||||
Workflow.app_id == pipeline.id,
|
||||
Workflow.version == "draft",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import json
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
|
@ -21,7 +20,7 @@ from services.plugin.plugin_service import PluginService
|
|||
|
||||
class RagPipelineTransformService:
|
||||
def transform_dataset(self, dataset_id: str):
|
||||
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise ValueError("Dataset not found")
|
||||
if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline":
|
||||
|
|
@ -90,7 +89,7 @@ class RagPipelineTransformService:
|
|||
"status": "success",
|
||||
}
|
||||
|
||||
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: Optional[str]):
|
||||
def _get_transform_yaml(self, doc_form: str, datasource_type: str, indexing_technique: str | None):
|
||||
pipeline_yaml = {}
|
||||
if doc_form == "text_model":
|
||||
match datasource_type:
|
||||
|
|
@ -152,7 +151,7 @@ class RagPipelineTransformService:
|
|||
return node
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self, dataset: Dataset, doc_form: str, indexing_technique: Optional[str], retrieval_model: dict, node: dict
|
||||
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
|
||||
|
|
@ -289,7 +288,7 @@ class RagPipelineTransformService:
|
|||
jina_node_id = "1752491761974"
|
||||
firecrawl_node_id = "1752565402678"
|
||||
|
||||
documents = db.session.query(Document).filter(Document.dataset_id == dataset.id).all()
|
||||
documents = db.session.query(Document).where(Document.dataset_id == dataset.id).all()
|
||||
|
||||
for document in documents:
|
||||
data_source_info_dict = document.data_source_info_dict
|
||||
|
|
@ -299,7 +298,7 @@ class RagPipelineTransformService:
|
|||
document.data_source_type = "local_file"
|
||||
file_id = data_source_info_dict.get("upload_file_id")
|
||||
if file_id:
|
||||
file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
|
||||
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
if file:
|
||||
data_source_info = json.dumps(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -99,12 +99,6 @@ class ToolTransformService:
|
|||
provider.declaration.identity.icon = ToolTransformService.get_plugin_icon_url(
|
||||
tenant_id=tenant_id, filename=provider.declaration.identity.icon
|
||||
)
|
||||
else:
|
||||
provider.declaration.identity.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value,
|
||||
provider_name=provider.name,
|
||||
icon=provider.declaration.identity.icon,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
|
@ -17,7 +16,7 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: Optional[str], file_ids: list[str]):
|
||||
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
|
||||
"""
|
||||
Clean document when document deleted.
|
||||
:param document_ids: document ids
|
||||
|
|
|
|||
|
|
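
This file makes the same typing change as the service above: `from typing import Optional` is dropped and annotations switch to PEP 604 unions (`str | None`). The two spellings are equivalent for type checkers; the union syntax in an evaluated signature needs Python 3.10+ (or `from __future__ import annotations`). A small sketch with an illustrative function, not the real Celery task:

    from typing import Optional


    def old_style(doc_form: Optional[str]) -> str:
        return doc_form or "text_model"


    def new_style(doc_form: str | None) -> str:  # same meaning, no typing import
        return doc_form or "text_model"


    assert old_style(None) == new_style(None) == "text_model"
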
@@ -33,7 +33,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -44,7 +44,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):

if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@@ -54,7 +54,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# add from vector index
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@@ -76,19 +76,19 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.filter(
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -100,7 +100,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@@ -113,7 +113,7 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
try:
segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@@ -148,12 +148,12 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
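
The task above flips `indexing_status` for a batch of documents before reindexing and records per-document errors afterwards. A minimal sketch of the bulk-update idiom it relies on (`query(...).where(...).update({...}, synchronize_session=False)` followed by a commit), using a stand-in Doc model rather than Dify's DatasetDocument:

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Doc(Base):
        __tablename__ = "docs"
        id: Mapped[str] = mapped_column(String, primary_key=True)
        indexing_status: Mapped[str] = mapped_column(String, default="completed")


    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Doc(id="a"), Doc(id="b")])
        session.commit()

        # Bulk UPDATE against the matched rows; skip in-session synchronization.
        doc_ids = ["a", "b"]
        session.query(Doc).where(Doc.id.in_(doc_ids)).update(
            {"indexing_status": "indexing"}, synchronize_session=False
        )
        session.commit()

        assert {d.indexing_status for d in session.query(Doc).all()} == {"indexing"}
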
@@ -104,20 +104,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],

with Session(db.engine, expire_on_commit=False) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")

tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant

pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")

workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

@@ -125,20 +125,20 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],

with Session(db.engine) as session:
# Load required entities
account = session.query(Account).filter(Account.id == user_id).first()
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")

tenant = session.query(Tenant).filter(Tenant.id == tenant_id).first()
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant

pipeline = session.query(Pipeline).filter(Pipeline.id == pipeline_id).first()
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")

workflow = session.query(Workflow).filter(Workflow.id == pipeline.workflow_id).first()
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")

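
Both hunks above repeat the same fail-fast lookup: open a session, load each entity by id, and raise ValueError when a row is missing; the first variant also passes expire_on_commit=False so loaded attributes stay usable after a commit. A sketch of that pattern with a hypothetical get_or_raise helper and a stand-in model, not Dify's code:

    from sqlalchemy import String, create_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class Thing(Base):
        __tablename__ = "things"
        id: Mapped[str] = mapped_column(String, primary_key=True)


    def get_or_raise(session: Session, model, entity_id: str, label: str):
        # Fail fast when the row does not exist, mirroring the lookups above.
        entity = session.query(model).where(model.id == entity_id).first()
        if not entity:
            raise ValueError(f"{label} {entity_id} not found")
        return entity


    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)

    # expire_on_commit=False keeps attributes readable after commit.
    with Session(engine, expire_on_commit=False) as session:
        session.add(Thing(id="t-1"))
        session.commit()
        thing = get_or_raise(session, Thing, "t-1", "Thing")
        assert thing.id == "t-1"
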
@@ -38,7 +38,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).filter(Tenant.id == dataset.tenant_id).first()
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant