mirror of https://github.com/langgenius/dify.git
fix style check
This commit is contained in:
parent f963eb525c
commit 69a402ba99
@@ -1429,9 +1429,9 @@ def transform_datasource_credentials():
notion_plugin_id = "langgenius/notion_datasource"
firecrawl_plugin_id = "langgenius/firecrawl_datasource"
jina_plugin_id = "langgenius/jina_datasource"
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage]
firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage]
jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage]
oauth_credential_type = CredentialType.OAUTH2
api_key_credential_type = CredentialType.API_KEY
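The replacement lines above silence a single Pyright rule on a single line instead of turning the checker off for the whole file. A minimal sketch of the pattern, with an invented class standing in for the migration helper:

```python
class PluginMigration:
    def _fetch_plugin_unique_identifier(self, plugin_id: str) -> str:
        # The underscore prefix marks this helper as internal.
        return f"{plugin_id}:0.0.1"


migration = PluginMigration()

# Accessing an underscore-prefixed member from outside its class (or module)
# is what reportPrivateUsage flags; the inline ignore waives only that rule
# on this line, so every other diagnostic still applies.
identifier = migration._fetch_plugin_unique_identifier(
    "langgenius/notion_datasource"
)  # pyright: ignore[reportPrivateUsage]
print(identifier)
```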
@@ -1,3 +1,3 @@
from .app_config import DifyConfig

dify_config = DifyConfig()
dify_config = DifyConfig() # pyright: ignore[reportCallIssue]
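`DifyConfig` appears to be a pydantic-settings style configuration class whose required fields are filled from environment variables, so a bare `DifyConfig()` call looks under-supplied to Pyright and trips reportCallIssue. A hedged sketch of the same situation with an invented settings class:

```python
from pydantic_settings import BaseSettings


class AppSettings(BaseSettings):
    # No default: the value is expected to arrive via the SECRET_KEY
    # environment variable, which the static checker cannot see.
    SECRET_KEY: str


# Pyright treats SECRET_KEY as a required constructor argument and reports a
# call issue, even though pydantic-settings reads it from os.environ at
# runtime (running this sketch needs SECRET_KEY set in the environment).
settings = AppSettings()  # pyright: ignore[reportCallIssue]
print(settings.SECRET_KEY)
```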
@@ -248,6 +248,8 @@ __all__ = [
"datasets",
"datasets_document",
"datasets_segments",
"datasource_auth",
"datasource_content_preview",
"email_register",
"endpoint",
"extension",

@@ -273,10 +275,16 @@ __all__ = [
"parameter",
"ping",
"plugin",
"rag_pipeline",
"rag_pipeline_datasets",
"rag_pipeline_draft_variable",
"rag_pipeline_import",
"rag_pipeline_workflow",
"recommended_app",
"saved_message",
"setup",
"site",
"spec",
"statistic",
"tags",
"tool_providers",
@@ -45,9 +45,9 @@ class BaseAppGenerator:
mapping=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType]
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType]
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType]
),
strict_type_validation=strict_type_validation,
)

@@ -60,9 +60,9 @@ class BaseAppGenerator:
mappings=v,
tenant_id=tenant_id,
config=FileUploadConfig(
allowed_file_types=entity_dictionary[k].allowed_file_types,
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
allowed_file_types=entity_dictionary[k].allowed_file_types, # pyright: ignore[reportArgumentType]
allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, # pyright: ignore[reportArgumentType]
allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods, # pyright: ignore[reportArgumentType]
),
)
for k, v in user_inputs.items()
@@ -124,7 +124,9 @@ class CompletionAppRunner(AppRunner):
config=dataset_config,
query=query or "",
invoke_from=application_generate_entity.invoke_from,
show_retrieve_source=app_config.additional_features.show_retrieve_source,
show_retrieve_source=app_config.additional_features.show_retrieve_source
if app_config.additional_features
else False,
hit_callback=hit_callback,
message_id=message.id,
inputs=inputs,
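The rewritten `show_retrieve_source` argument above falls back to `False` when `additional_features` is `None` instead of dereferencing it unconditionally. A small sketch of the pattern, using made-up dataclasses in place of the real app-config types:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Features:
    show_retrieve_source: bool = True


@dataclass
class AppConfig:
    additional_features: Optional[Features] = None


def resolve_show_retrieve_source(config: AppConfig) -> bool:
    # A direct config.additional_features.show_retrieve_source would fail at
    # runtime (and be flagged by Pyright) when the field is None, so the
    # conditional expression supplies a default instead.
    return config.additional_features.show_retrieve_source if config.additional_features else False


print(resolve_show_retrieve_source(AppConfig()))                # False
print(resolve_show_retrieve_source(AppConfig(Features(True))))  # True
```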
@@ -58,7 +58,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk

@@ -87,9 +87,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

if isinstance(sub_stream_response, ErrorStreamResponse):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
response_chunk.update(cast(dict, data))
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.to_dict())
yield response_chunk
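`cast` is only a hint to the type checker and performs no conversion at runtime, which is why wrapping the payload in `cast(dict, ...)` is enough to satisfy `dict.update` here. A brief illustration; the helper and its deliberately broad return annotation are invented:

```python
from typing import Any, cast


def error_payload() -> object:
    # Stand-in for a helper whose declared return type is wider than what it
    # actually produces.
    return {"code": "internal_error", "message": "boom"}


response_chunk: dict[str, Any] = {"event": "error"}

# cast() returns its argument unchanged at runtime; it only tells the checker
# to treat the value as a dict so that dict.update() accepts it.
response_chunk.update(cast(dict, error_payload()))
print(response_chunk)
```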
@@ -71,7 +71,7 @@ class PipelineGenerator(BaseAppGenerator):
call_depth: int,
workflow_thread_pool_id: Optional[str],
is_retry: bool = False,
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
) -> Generator[Mapping | str, None, None]: ...

@overload
def generate(
@@ -260,26 +260,6 @@ class RagPipelineGenerateEntity(WorkflowAppGenerateEntity):
original_document_id: Optional[str] = None
start_node_id: Optional[str] = None

class SingleIterationRunEntity(BaseModel):
"""
Single Iteration Run Entity.
"""

node_id: str
inputs: dict

single_iteration_run: Optional[SingleIterationRunEntity] = None

class SingleLoopRunEntity(BaseModel):
"""
Single Loop Run Entity.
"""

node_id: str
inputs: dict

single_loop_run: Optional[SingleLoopRunEntity] = None


# Import TraceQueueManager at runtime to resolve forward references
from core.ops.ops_trace_manager import TraceQueueManager
@@ -138,6 +138,8 @@ class MessageCycleManager:
:param event: event
:return:
"""
if not self._application_generate_entity.app_config.additional_features:
raise ValueError("Additional features not found")
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
self._task_state.metadata.retriever_resources = event.retriever_resources
@@ -136,7 +136,6 @@ class DatasourceFileManager:
original_url=file_url,
name=filename,
size=len(blob),
key=filepath,
)

db.session.add(tool_file)
@@ -1,265 +0,0 @@
from copy import deepcopy
from typing import Any

from pydantic import BaseModel

from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)


class ProviderConfigEncrypter(BaseModel):
tenant_id: str
config: list[BasicProviderConfig]
provider_type: str
provider_identity: str

def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
"""
deep copy data
"""
return deepcopy(data)

def encrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
encrypt tool credentials with tenant id

return a deep copy of credentials with encrypted values
"""
data = self._deep_copy(data)

# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential

for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
data[field_name] = encrypted

return data

def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
"""
mask tool credentials

return a deep copy of credentials with masked values
"""
data = self._deep_copy(data)

# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential

for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
if len(data[field_name]) > 6:
data[field_name] = (
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
)
else:
data[field_name] = "*" * len(data[field_name])

return data

def decrypt(self, data: dict[str, str]) -> dict[str, str]:
"""
decrypt tool credentials with tenant id

return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
data = self._deep_copy(data)
# get fields need to be decrypted
fields = dict[str, BasicProviderConfig]()
for credential in self.config:
fields[credential.name] = credential

for field_name, field in fields.items():
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
if field_name in data:
try:
# if the value is None or empty string, skip decrypt
if not data[field_name]:
continue

data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
except Exception:
pass

cache.set(data)
return data

def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_type}.{self.provider_identity}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()


class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
"""

tenant_id: str
tool_runtime: Tool
provider_name: str
provider_type: ToolProviderType
identity_id: str

def __init__(
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
) -> None:
self.tenant_id = tenant_id
self.tool_runtime = tool_runtime
self.provider_name = provider_name
self.provider_type = provider_type
self.identity_id = identity_id

def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
deep copy parameters
"""
return deepcopy(parameters)

def _merge_parameters(self) -> list[ToolParameter]:
"""
merge parameters
"""
# get tool parameters
tool_parameters = self.tool_runtime.entity.parameters or []
# get tool runtime parameters
runtime_parameters = self.tool_runtime.get_runtime_parameters()
# override parameters
current_parameters = tool_parameters.copy()
for runtime_parameter in runtime_parameters:
found = False
for index, parameter in enumerate(current_parameters):
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
current_parameters[index] = runtime_parameter
found = True
break

if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)

return current_parameters

def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
mask tool parameters

return a deep copy of parameters with masked values
"""
parameters = self._deep_copy(parameters)

# override parameters
current_parameters = self._merge_parameters()

for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = "*" * len(parameters[parameter.name])

return parameters

def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
encrypt tool parameters with tenant id

return a deep copy of parameters with encrypted values
"""
# override parameters
current_parameters = self._merge_parameters()

parameters = self._deep_copy(parameters)

for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted

return parameters

def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
"""
decrypt tool parameters with tenant id

return a deep copy of parameters with decrypted values
"""

cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
return cached_parameters

# override parameters
current_parameters = self._merge_parameters()
has_secret_input = False

for parameter in current_parameters:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
except Exception:
pass

if has_secret_input:
cache.set(parameters)

return parameters

def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type.value}.{self.provider_name}",
tool_name=self.tool_runtime.entity.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id,
)
cache.delete()
@@ -62,7 +62,7 @@ class DatasourceFileMessageTransformer:

mimetype = meta.get("mime_type")
if not mimetype:
mimetype = guess_type(filename)[0] or "application/octet-stream"
mimetype = guess_type(filename)[0] or "application/octet-stream" # pyright: ignore[reportArgumentType]

# if message is str, encode it to bytes
@@ -100,7 +100,7 @@ class ApiBasedToolSchemaParser:
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root

# parse body parameters
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
if "schema" in interface["operation"]["requestBody"]["content"][content_type]: # pyright: ignore[reportIndexIssue, reportPossiblyUnboundVariable]
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})

@@ -247,7 +247,7 @@ class ApiBasedToolSchemaParser:

# convert paths
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
openapi["paths"][path] = {} # pyright: ignore[reportIndexIssue]
for method, operation in path_item.items():
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")

@@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser:
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."

openapi["paths"][path][method] = {
openapi["paths"][path][method] = { # pyright: ignore[reportIndexIssue]
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),

@@ -267,11 +267,11 @@ class ApiBasedToolSchemaParser:
}

if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # pyright: ignore[reportIndexIssue]

# convert definitions
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
openapi["components"]["schemas"][name] = definition # pyright: ignore[reportIndexIssue, reportArgumentType]

return openapi
@@ -1,15 +0,0 @@
from typing import TYPE_CHECKING, Any, cast

from core.datasource import datasource_file_manager
from core.datasource.datasource_file_manager import DatasourceFileManager

if TYPE_CHECKING:
from core.datasource.datasource_file_manager import DatasourceFileManager

tool_file_manager: dict[str, Any] = {"manager": None}


class DatasourceFileParser:
@staticmethod
def get_datasource_file_manager() -> "DatasourceFileManager":
return cast("DatasourceFileManager", datasource_file_manager["manager"])
@@ -102,7 +102,7 @@ def download(f: File, /):
FileTransferMethod.LOCAL_FILE,
FileTransferMethod.DATASOURCE_FILE,
):
return _download_file_content(f._storage_key)
return _download_file_content(f.storage_key)
elif f.transfer_method == FileTransferMethod.REMOTE_URL:
response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
response.raise_for_status()

@@ -141,6 +141,8 @@ def _get_encoded_string(f: File, /):
data = _download_file_content(f.storage_key)
case FileTransferMethod.TOOL_FILE:
data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)

encoded_string = base64.b64encode(data).decode("utf-8")
return encoded_string
@@ -145,6 +145,9 @@ class File(BaseModel):
case FileTransferMethod.TOOL_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
case FileTransferMethod.DATASOURCE_FILE:
if not self.related_id:
raise ValueError("Missing file related_id")
return self

@property
@@ -288,7 +288,8 @@ class DatasetService:
names,
"Untitled",
)

if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
pipeline = Pipeline(
tenant_id=tenant_id,
name=rag_pipeline_dataset_create_entity.name,

@@ -814,6 +815,8 @@ class DatasetService:
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
):
if not current_user or not current_user.current_tenant_id:
raise ValueError("Current user or current tenant not found")
dataset = session.merge(dataset)
if not has_published:
dataset.chunk_structure = knowledge_configuration.chunk_structure

@@ -821,7 +824,7 @@ class DatasetService:
if knowledge_configuration.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
tenant_id=current_user.current_tenant_id, # ignore type error
provider=knowledge_configuration.embedding_model_provider or "",
model_type=ModelType.TEXT_EMBEDDING,
model=knowledge_configuration.embedding_model or "",

@@ -895,6 +898,7 @@ class DatasetService:
):
action = "update"
model_manager = ModelManager()
embedding_model = None
try:
embedding_model = model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,

@@ -908,14 +912,15 @@ class DatasetService:
# Skip the rest of the embedding model update
skip_embedding_update = True
if not skip_embedding_update:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
if embedding_model:
dataset.embedding_model = embedding_model.model
dataset.embedding_model_provider = embedding_model.provider
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_model.provider, embedding_model.model
)
)
)
dataset.collection_binding_id = dataset_collection_binding.id
dataset.collection_binding_id = dataset_collection_binding.id
except LLMBadRequestError:
raise ValueError(
"No Embedding Model available. Please configure a valid provider "

@@ -1014,6 +1019,8 @@ class DatasetService:
if dataset is None:
raise NotFound("Dataset not found.")
dataset.enable_api = status
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
dataset.updated_by = current_user.id
dataset.updated_at = naive_utc_now()
db.session.commit()

@@ -1350,6 +1357,8 @@ class DocumentService:
redis_client.setex(retry_indexing_cache_key, 600, 1)
# trigger async task
document_ids = [document.id for document in documents]
if not current_user or not current_user.id:
raise ValueError("Current user or current user id not found")
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)

@staticmethod
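Two related changes above make the embedding-model flow explicit: `embedding_model` is initialized to `None` before the `try` block, and later use is wrapped in `if embedding_model:` so both the runtime and the type checker see that the lookup may have failed. A compact sketch with placeholder classes:

```python
from typing import Optional


class ModelInstance:
    provider = "openai"
    model = "text-embedding-3-small"


def get_model_instance() -> ModelInstance:
    # Stand-in for a lookup that may raise when no provider is configured.
    raise ValueError("no embedding model configured")


def resolve_binding() -> Optional[str]:
    embedding_model: Optional[ModelInstance] = None
    try:
        embedding_model = get_model_instance()
    except ValueError:
        pass  # keep going; embedding_model stays None
    # The guard narrows Optional[ModelInstance] to ModelInstance for the type
    # checker and avoids touching a possibly-absent value at runtime.
    if embedding_model:
        return f"{embedding_model.provider}/{embedding_model.model}"
    return None


print(resolve_binding())  # None, because the lookup above always raises
```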
@@ -1,5 +1,6 @@
import logging
import time
from collections.abc import Mapping
from typing import Any

from flask_login import current_user

@@ -68,11 +69,13 @@ class DatasourceProviderService:
tenant_id: str,
provider: str,
plugin_id: str,
raw_credentials: dict[str, Any],
raw_credentials: Mapping[str, Any],
datasource_provider: DatasourceProvider,
) -> dict[str, Any]:
provider_credential_secret_variables = self.extract_secret_variables(
tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type
tenant_id=tenant_id,
provider_id=f"{plugin_id}/{provider}",
credential_type=CredentialType.of(datasource_provider.auth_type),
)
encrypted_credentials = raw_credentials.copy()
for key, value in encrypted_credentials.items():
@@ -241,6 +241,9 @@ class MessageService:

app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

if not app_config.additional_features:
raise ValueError("Additional features not found")

if not app_config.additional_features.suggested_questions_after_answer:
raise SuggestedQuestionsAfterAnswerDisabledError()
@@ -828,10 +828,10 @@ class RagPipelineService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
node_instance = e._node
node_instance = e._node # type: ignore
run_succeeded = False
node_run_result = None
error = e._error
error = e._error # type: ignore

workflow_node_execution = WorkflowNodeExecution(
id=str(uuid4()),

@@ -1253,7 +1253,7 @@ class RagPipelineService:
repository.save(workflow_node_execution)

# Convert node_execution to WorkflowNodeExecution after save
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)
workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution) # type: ignore

with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
@@ -47,6 +47,8 @@ class RagPipelineTransformService:
self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
# Extract app data
workflow_data = pipeline_yaml.get("workflow")
if not workflow_data:
raise ValueError("Missing workflow data for rag pipeline")
graph = workflow_data.get("graph", {})
nodes = graph.get("nodes", [])
new_nodes = []

@@ -252,7 +254,7 @@ class RagPipelineTransformService:
plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
plugin_id = plugin_unique_identifier.split(":")[0]
if plugin_id not in installed_plugins_ids:
plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)
plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id) # type: ignore
if plugin_unique_identifier:
need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
if need_install_plugin_unique_identifiers:
@@ -1,7 +1,6 @@
import dataclasses
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, overload
from typing import Any, Generic, TypeAlias, TypeVar, overload

from configs import dify_config
from core.file.models import File

@@ -39,30 +38,6 @@ class _PCKeys:
CHILD_CONTENTS = "child_contents"


class _QAStructureItem(TypedDict):
question: str
answer: str


class _QAStructure(TypedDict):
qa_chunks: list[_QAStructureItem]


class _ParentChildChunkItem(TypedDict):
parent_content: str
child_contents: list[str]


class _ParentChildStructure(TypedDict):
parent_mode: str
parent_child_chunks: list[_ParentChildChunkItem]


class _SpecialChunkType(StrEnum):
parent_child = "parent_child"
qa = "qa"


_T = TypeVar("_T")


@@ -392,7 +367,7 @@ class VariableTruncator:
def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...

@overload
def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...
def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ... # type: ignore

@overload
def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...
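The `# type: ignore` added to the `bool` overload above is the usual workaround for the fact that `bool` is a subclass of `int`: the two overloads overlap, and checkers complain when overlapping overloads declare incompatible return types. A hedged, self-contained sketch of the same situation:

```python
from typing import overload


@overload
def render(val: bool) -> str: ...  # type: ignore  # overlaps the int overload below
@overload
def render(val: int) -> int: ...
def render(val: bool | int) -> str | int:
    # bool is a subtype of int, so without the ignore a checker may report the
    # bool overload as overlapping the int one with an incompatible return type.
    if isinstance(val, bool):
        return "yes" if val else "no"
    return val * 2


print(render(True), render(3))
```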
@@ -146,7 +146,7 @@ class WorkflowConverter:
graph=graph,
model_config=app_config.model,
prompt_template=app_config.prompt_template,
file_upload=app_config.additional_features.file_upload,
file_upload=app_config.additional_features.file_upload if app_config.additional_features else None,
external_data_variable_node_mapping=external_data_variable_node_mapping,
)
@@ -430,6 +430,10 @@ class WorkflowDraftVariableService:
.where(WorkflowDraftVariable.id == variable.id)
)
variable_reloaded = self._session.execute(variable_query).scalars().first()
if variable_reloaded is None:
logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable.id)
self._session.delete(variable)
return
variable_file = variable_reloaded.variable_file
if variable_file is None:
logger.warning(
@@ -811,7 +811,7 @@ class WorkflowService:
return node, node_run_result, run_succeeded, error

except WorkflowNodeRunFailedError as e:
return e._node, None, False, e._error
return e._node, None, False, e._error # type: ignore

def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
"""Apply error strategy when node execution fails."""
@@ -156,7 +156,8 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
# Using protected method intentionally for async execution
pipeline_generator._generate( # type: ignore[attr-defined]
flask_app=flask_app,
context=context,
pipeline=pipeline,
@@ -177,7 +177,8 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

pipeline_generator = PipelineGenerator()
pipeline_generator._generate(
# Using protected method intentionally for async execution
pipeline_generator._generate( # type: ignore[attr-defined]
flask_app=flask_app,
context=context,
pipeline=pipeline,
@@ -452,12 +452,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:

# Delete from object storage and collect upload file IDs
upload_file_ids = []
for variable_file_id, storage_key, upload_file_id in file_records:
for _, storage_key, upload_file_id in file_records:
try:
storage.delete(storage_key)
upload_file_ids.append(upload_file_id)
files_deleted += 1
except Exception as e:
except Exception:
logging.exception("Failed to delete storage object %s", storage_key)
# Continue with database cleanup even if storage deletion fails
upload_file_ids.append(upload_file_id)
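Two small lint-driven fixes appear above: the unused tuple element becomes `_`, and the `as e` binding is dropped because `logging.exception` already records the active exception and its traceback. A compact sketch of the same cleanup with made-up records:

```python
import logging

file_records = [
    ("var-1", "objects/a.bin", "upload-1"),
    ("var-2", "objects/b.bin", "upload-2"),
]


def delete_blob(storage_key: str) -> None:
    raise OSError(f"cannot delete {storage_key}")  # simulate a storage failure


upload_file_ids = []
for _, storage_key, upload_file_id in file_records:  # "_" marks the unused field
    try:
        delete_blob(storage_key)
    except Exception:
        # logging.exception attaches the current traceback automatically, so
        # binding the exception to a name it never uses is unnecessary.
        logging.exception("Failed to delete storage object %s", storage_key)
    # Continue with the cleanup even if the storage deletion failed.
    upload_file_ids.append(upload_file_id)

print(upload_file_ids)
```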