diff --git a/api/commands.py b/api/commands.py
index 39c40fdf73..3ff0d1fbe1 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1429,9 +1429,9 @@ def transform_datasource_credentials():
     notion_plugin_id = "langgenius/notion_datasource"
     firecrawl_plugin_id = "langgenius/firecrawl_datasource"
     jina_plugin_id = "langgenius/jina_datasource"
-    notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)
-    firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)
-    jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)
+    notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id)  # pyright: ignore[reportPrivateUsage]
+    firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id)  # pyright: ignore[reportPrivateUsage]
+    jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id)  # pyright: ignore[reportPrivateUsage]
     oauth_credential_type = CredentialType.OAUTH2
     api_key_credential_type = CredentialType.API_KEY
diff --git a/api/configs/__init__.py b/api/configs/__init__.py
index 3a172601c9..04642b5e9a 100644
--- a/api/configs/__init__.py
+++ b/api/configs/__init__.py
@@ -1,3 +1,3 @@
 from .app_config import DifyConfig

-dify_config = DifyConfig()
+dify_config = DifyConfig()  # pyright: ignore[reportCallIssue]
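The hunks above add rule-scoped pyright suppressions rather than blanket `# type: ignore` comments. A minimal sketch of the distinction, with a hypothetical private helper standing in for the real `plugin_migration` module:

```python
class PluginMigration:
    def _fetch_plugin_unique_identifier(self, plugin_id: str) -> str:
        # Hypothetical stand-in; the leading underscore is what triggers
        # pyright's reportPrivateUsage at external call sites.
        return f"{plugin_id}:0.0.1@abc123"


migration = PluginMigration()

# Rule-scoped: silences only reportPrivateUsage, only on this line.
ident = migration._fetch_plugin_unique_identifier("langgenius/notion_datasource")  # pyright: ignore[reportPrivateUsage]

# Blanket (avoided above): would also hide any future, unrelated error here.
# ident = migration._fetch_plugin_unique_identifier("...")  # type: ignore
```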
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 4339501f73..ee02ff3937 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -248,6 +248,8 @@ __all__ = [
     "datasets",
     "datasets_document",
     "datasets_segments",
+    "datasource_auth",
+    "datasource_content_preview",
     "email_register",
     "endpoint",
     "extension",
@@ -273,10 +275,16 @@ __all__ = [
     "parameter",
     "ping",
     "plugin",
+    "rag_pipeline",
+    "rag_pipeline_datasets",
+    "rag_pipeline_draft_variable",
+    "rag_pipeline_import",
+    "rag_pipeline_workflow",
     "recommended_app",
     "saved_message",
     "setup",
     "site",
+    "spec",
     "statistic",
     "tags",
     "tool_providers",
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 90a8040d81..b8ff40cb5f 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -45,9 +45,9 @@ class BaseAppGenerator:
                     mapping=v,
                     tenant_id=tenant_id,
                     config=FileUploadConfig(
-                        allowed_file_types=entity_dictionary[k].allowed_file_types,
-                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
-                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                        allowed_file_types=entity_dictionary[k].allowed_file_types,  # pyright: ignore[reportArgumentType]
+                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,  # pyright: ignore[reportArgumentType]
+                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,  # pyright: ignore[reportArgumentType]
                     ),
                     strict_type_validation=strict_type_validation,
                 )
@@ -60,9 +60,9 @@ class BaseAppGenerator:
                     mappings=v,
                     tenant_id=tenant_id,
                     config=FileUploadConfig(
-                        allowed_file_types=entity_dictionary[k].allowed_file_types,
-                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
-                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,
+                        allowed_file_types=entity_dictionary[k].allowed_file_types,  # pyright: ignore[reportArgumentType]
+                        allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,  # pyright: ignore[reportArgumentType]
+                        allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods,  # pyright: ignore[reportArgumentType]
                     ),
                 )
                 for k, v in user_inputs.items()
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 6c4bf4139e..e2be4146e1 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -124,7 +124,9 @@ class CompletionAppRunner(AppRunner):
                 config=dataset_config,
                 query=query or "",
                 invoke_from=application_generate_entity.invoke_from,
-                show_retrieve_source=app_config.additional_features.show_retrieve_source,
+                show_retrieve_source=app_config.additional_features.show_retrieve_source
+                if app_config.additional_features
+                else False,
                 hit_callback=hit_callback,
                 message_id=message.id,
                 inputs=inputs,
diff --git a/api/core/app/apps/pipeline/generate_response_converter.py b/api/core/app/apps/pipeline/generate_response_converter.py
index 10ec73a7d2..e125958180 100644
--- a/api/core/app/apps/pipeline/generate_response_converter.py
+++ b/api/core/app/apps/pipeline/generate_response_converter.py
@@ -58,7 +58,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

                 if isinstance(sub_stream_response, ErrorStreamResponse):
                     data = cls._error_to_stream_response(sub_stream_response.err)
-                    response_chunk.update(data)
+                    response_chunk.update(cast(dict, data))
                 else:
                     response_chunk.update(sub_stream_response.to_dict())
             yield response_chunk
@@ -87,9 +87,9 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):

                 if isinstance(sub_stream_response, ErrorStreamResponse):
                     data = cls._error_to_stream_response(sub_stream_response.err)
-                    response_chunk.update(data)
+                    response_chunk.update(cast(dict, data))
                 elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                    response_chunk.update(sub_stream_response.to_ignore_detail_dict())
+                    response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
                 else:
                     response_chunk.update(sub_stream_response.to_dict())
             yield response_chunk
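The converter hunks above narrow values with `typing.cast` before `dict.update()`. `cast` is a type-checker-only assertion with no runtime effect, so it only belongs where the surrounding `isinstance` check already guarantees the shape. A self-contained sketch; the union return type here is an assumption, not the actual signature of `_error_to_stream_response`:

```python
from typing import Any, cast


def error_to_stream_response(err: Exception) -> dict[str, Any] | None:
    # Hypothetical helper: returns None when the error carries no payload.
    return {"event": "error", "message": str(err)} if str(err) else None


response_chunk: dict[str, Any] = {"task_id": "t-1"}
data = error_to_stream_response(RuntimeError("boom"))
if data is not None:
    # cast() narrows the declared type for the checker only; it performs
    # no conversion or validation at runtime.
    response_chunk.update(cast(dict, data))
```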
- """ - - node_id: str - inputs: dict - - single_loop_run: Optional[SingleLoopRunEntity] = None - # Import TraceQueueManager at runtime to resolve forward references from core.ops.ops_trace_manager import TraceQueueManager diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 90ffdcf1f6..0004fb592e 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -138,6 +138,8 @@ class MessageCycleManager: :param event: event :return: """ + if not self._application_generate_entity.app_config.additional_features: + raise ValueError("Additional features not found") if self._application_generate_entity.app_config.additional_features.show_retrieve_source: self._task_state.metadata.retriever_resources = event.retriever_resources diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 66d69499e7..beb5ce7b04 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -136,7 +136,6 @@ class DatasourceFileManager: original_url=file_url, name=filename, size=len(blob), - key=filepath, ) db.session.add(tool_file) diff --git a/api/core/datasource/utils/configuration.py b/api/core/datasource/utils/configuration.py deleted file mode 100644 index 6a5fba65bd..0000000000 --- a/api/core/datasource/utils/configuration.py +++ /dev/null @@ -1,265 +0,0 @@ -from copy import deepcopy -from typing import Any - -from pydantic import BaseModel - -from core.entities.provider_entities import BasicProviderConfig -from core.helper import encrypter -from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType -from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolParameter, - ToolProviderType, -) - - -class ProviderConfigEncrypter(BaseModel): - tenant_id: str - config: list[BasicProviderConfig] - provider_type: str - provider_identity: str - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * 
diff --git a/api/core/datasource/utils/configuration.py b/api/core/datasource/utils/configuration.py
deleted file mode 100644
index 6a5fba65bd..0000000000
--- a/api/core/datasource/utils/configuration.py
+++ /dev/null
@@ -1,265 +0,0 @@
-from copy import deepcopy
-from typing import Any
-
-from pydantic import BaseModel
-
-from core.entities.provider_entities import BasicProviderConfig
-from core.helper import encrypter
-from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
-from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
-from core.tools.__base.tool import Tool
-from core.tools.entities.tool_entities import (
-    ToolParameter,
-    ToolProviderType,
-)
-
-
-class ProviderConfigEncrypter(BaseModel):
-    tenant_id: str
-    config: list[BasicProviderConfig]
-    provider_type: str
-    provider_identity: str
-
-    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
-        """
-        deep copy data
-        """
-        return deepcopy(data)
-
-    def encrypt(self, data: dict[str, str]) -> dict[str, str]:
-        """
-        encrypt tool credentials with tenant id
-
-        return a deep copy of credentials with encrypted values
-        """
-        data = self._deep_copy(data)
-
-        # get fields need to be decrypted
-        fields = dict[str, BasicProviderConfig]()
-        for credential in self.config:
-            fields[credential.name] = credential
-
-        for field_name, field in fields.items():
-            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
-                if field_name in data:
-                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
-                    data[field_name] = encrypted
-
-        return data
-
-    def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
-        """
-        mask tool credentials
-
-        return a deep copy of credentials with masked values
-        """
-        data = self._deep_copy(data)
-
-        # get fields need to be decrypted
-        fields = dict[str, BasicProviderConfig]()
-        for credential in self.config:
-            fields[credential.name] = credential
-
-        for field_name, field in fields.items():
-            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
-                if field_name in data:
-                    if len(data[field_name]) > 6:
-                        data[field_name] = (
-                            data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
-                        )
-                    else:
-                        data[field_name] = "*" * len(data[field_name])
-
-        return data
-
-    def decrypt(self, data: dict[str, str]) -> dict[str, str]:
-        """
-        decrypt tool credentials with tenant id
-
-        return a deep copy of credentials with decrypted values
-        """
-        cache = ToolProviderCredentialsCache(
-            tenant_id=self.tenant_id,
-            identity_id=f"{self.provider_type}.{self.provider_identity}",
-            cache_type=ToolProviderCredentialsCacheType.PROVIDER,
-        )
-        cached_credentials = cache.get()
-        if cached_credentials:
-            return cached_credentials
-        data = self._deep_copy(data)
-        # get fields need to be decrypted
-        fields = dict[str, BasicProviderConfig]()
-        for credential in self.config:
-            fields[credential.name] = credential
-
-        for field_name, field in fields.items():
-            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
-                if field_name in data:
-                    try:
-                        # if the value is None or empty string, skip decrypt
-                        if not data[field_name]:
-                            continue
-
-                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
-                    except Exception:
-                        pass
-
-        cache.set(data)
-        return data
-
-    def delete_tool_credentials_cache(self):
-        cache = ToolProviderCredentialsCache(
-            tenant_id=self.tenant_id,
-            identity_id=f"{self.provider_type}.{self.provider_identity}",
-            cache_type=ToolProviderCredentialsCacheType.PROVIDER,
-        )
-        cache.delete()
-
-
-class ToolParameterConfigurationManager:
-    """
-    Tool parameter configuration manager
-    """
-
-    tenant_id: str
-    tool_runtime: Tool
-    provider_name: str
-    provider_type: ToolProviderType
-    identity_id: str
-
-    def __init__(
-        self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
-    ) -> None:
-        self.tenant_id = tenant_id
-        self.tool_runtime = tool_runtime
-        self.provider_name = provider_name
-        self.provider_type = provider_type
-        self.identity_id = identity_id
-
-    def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
-        """
-        deep copy parameters
-        """
-        return deepcopy(parameters)
-
-    def _merge_parameters(self) -> list[ToolParameter]:
-        """
-        merge parameters
-        """
-        # get tool parameters
-        tool_parameters = self.tool_runtime.entity.parameters or []
-        # get tool runtime parameters
-        runtime_parameters = self.tool_runtime.get_runtime_parameters()
-        # override parameters
-        current_parameters = tool_parameters.copy()
-        for runtime_parameter in runtime_parameters:
-            found = False
-            for index, parameter in enumerate(current_parameters):
-                if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
-                    current_parameters[index] = runtime_parameter
-                    found = True
-                    break
-
-            if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
-                current_parameters.append(runtime_parameter)
-
-        return current_parameters
-
-    def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
-        """
-        mask tool parameters
-
-        return a deep copy of parameters with masked values
-        """
-        parameters = self._deep_copy(parameters)
-
-        # override parameters
-        current_parameters = self._merge_parameters()
-
-        for parameter in current_parameters:
-            if (
-                parameter.form == ToolParameter.ToolParameterForm.FORM
-                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
-            ):
-                if parameter.name in parameters:
-                    if len(parameters[parameter.name]) > 6:
-                        parameters[parameter.name] = (
-                            parameters[parameter.name][:2]
-                            + "*" * (len(parameters[parameter.name]) - 4)
-                            + parameters[parameter.name][-2:]
-                        )
-                    else:
-                        parameters[parameter.name] = "*" * len(parameters[parameter.name])
-
-        return parameters
-
-    def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
-        """
-        encrypt tool parameters with tenant id
-
-        return a deep copy of parameters with encrypted values
-        """
-        # override parameters
-        current_parameters = self._merge_parameters()
-
-        parameters = self._deep_copy(parameters)
-
-        for parameter in current_parameters:
-            if (
-                parameter.form == ToolParameter.ToolParameterForm.FORM
-                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
-            ):
-                if parameter.name in parameters:
-                    encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
-                    parameters[parameter.name] = encrypted
-
-        return parameters
-
-    def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
-        """
-        decrypt tool parameters with tenant id
-
-        return a deep copy of parameters with decrypted values
-        """
-
-        cache = ToolParameterCache(
-            tenant_id=self.tenant_id,
-            provider=f"{self.provider_type.value}.{self.provider_name}",
-            tool_name=self.tool_runtime.entity.identity.name,
-            cache_type=ToolParameterCacheType.PARAMETER,
-            identity_id=self.identity_id,
-        )
-        cached_parameters = cache.get()
-        if cached_parameters:
-            return cached_parameters
-
-        # override parameters
-        current_parameters = self._merge_parameters()
-        has_secret_input = False
-
-        for parameter in current_parameters:
-            if (
-                parameter.form == ToolParameter.ToolParameterForm.FORM
-                and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
-            ):
-                if parameter.name in parameters:
-                    try:
-                        has_secret_input = True
-                        parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
-                    except Exception:
-                        pass
-
-        if has_secret_input:
-            cache.set(parameters)
-
-        return parameters
-
-    def delete_tool_parameters_cache(self):
-        cache = ToolParameterCache(
-            tenant_id=self.tenant_id,
-            provider=f"{self.provider_type.value}.{self.provider_name}",
-            tool_name=self.tool_runtime.entity.identity.name,
-            cache_type=ToolParameterCacheType.PARAMETER,
-            identity_id=self.identity_id,
-        )
-        cache.delete()
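The deleted `ProviderConfigEncrypter` masked SECRET_INPUT fields by keeping two characters at each end whenever the value is long enough to hide anything. Extracting just that masking rule as a standalone function, for reference:

```python
def mask_secret(value: str) -> str:
    # Keep two characters at each end only when enough remains to hide;
    # short secrets are masked entirely.
    if len(value) > 6:
        return value[:2] + "*" * (len(value) - 4) + value[-2:]
    return "*" * len(value)


assert mask_secret("sk-abcdef123") == "sk********23"
assert mask_secret("abc") == "***"
```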
diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py
index d249e02064..bfdd5c214d 100644
--- a/api/core/datasource/utils/message_transformer.py
+++ b/api/core/datasource/utils/message_transformer.py
@@ -62,7 +62,7 @@ class DatasourceFileMessageTransformer:
                 mimetype = meta.get("mime_type")
                 if not mimetype:
-                    mimetype = guess_type(filename)[0] or "application/octet-stream"
+                    mimetype = guess_type(filename)[0] or "application/octet-stream"  # pyright: ignore[reportArgumentType]

                 # if message is str, encode it to bytes
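On the `guess_type` change above: `mimetypes.guess_type` returns a `(type, encoding)` tuple whose first element may be `None`, and the suppression is needed because `filename` itself may be `None` at that call site. A sketch of the same fallback chain with the nullability handled explicitly instead:

```python
from mimetypes import guess_type


def resolve_mimetype(filename: str | None, meta_mimetype: str | None = None) -> str:
    # Prefer explicit metadata, then extension sniffing, then a safe default.
    return meta_mimetype or guess_type(filename or "")[0] or "application/octet-stream"


print(resolve_mimetype("report.pdf"))      # application/pdf
print(resolve_mimetype("blob.unknown"))    # application/octet-stream
print(resolve_mimetype(None, "text/csv"))  # text/csv
```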
openapi["paths"][path] = {} + openapi["paths"][path] = {} # pyright: ignore[reportIndexIssue] for method, operation in path_item.items(): if "operationId" not in operation: raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.") @@ -258,7 +258,7 @@ class ApiBasedToolSchemaParser: if warning is not None: warning["missing_summary"] = f"No summary or description found in operation {method} {path}." - openapi["paths"][path][method] = { + openapi["paths"][path][method] = { # pyright: ignore[reportIndexIssue] "operationId": operation["operationId"], "summary": operation.get("summary", ""), "description": operation.get("description", ""), @@ -267,11 +267,11 @@ class ApiBasedToolSchemaParser: } if "requestBody" in operation: - openapi["paths"][path][method]["requestBody"] = operation["requestBody"] + openapi["paths"][path][method]["requestBody"] = operation["requestBody"] # pyright: ignore[reportIndexIssue] # convert definitions for name, definition in swagger["definitions"].items(): - openapi["components"]["schemas"][name] = definition + openapi["components"]["schemas"][name] = definition # pyright: ignore[reportIndexIssue, reportArgumentType] return openapi diff --git a/api/core/file/datasource_file_parser.py b/api/core/file/datasource_file_parser.py deleted file mode 100644 index 52687951ac..0000000000 --- a/api/core/file/datasource_file_parser.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import TYPE_CHECKING, Any, cast - -from core.datasource import datasource_file_manager -from core.datasource.datasource_file_manager import DatasourceFileManager - -if TYPE_CHECKING: - from core.datasource.datasource_file_manager import DatasourceFileManager - -tool_file_manager: dict[str, Any] = {"manager": None} - - -class DatasourceFileParser: - @staticmethod - def get_datasource_file_manager() -> "DatasourceFileManager": - return cast("DatasourceFileManager", datasource_file_manager["manager"]) diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index ce2ece48e1..120fb73cdb 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -102,7 +102,7 @@ def download(f: File, /): FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE, ): - return _download_file_content(f._storage_key) + return _download_file_content(f.storage_key) elif f.transfer_method == FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() @@ -141,6 +141,8 @@ def _get_encoded_string(f: File, /): data = _download_file_content(f.storage_key) case FileTransferMethod.TOOL_FILE: data = _download_file_content(f.storage_key) + case FileTransferMethod.DATASOURCE_FILE: + data = _download_file_content(f.storage_key) encoded_string = base64.b64encode(data).decode("utf-8") return encoded_string diff --git a/api/core/file/models.py b/api/core/file/models.py index 990d3fe91d..a242851a63 100644 --- a/api/core/file/models.py +++ b/api/core/file/models.py @@ -145,6 +145,9 @@ class File(BaseModel): case FileTransferMethod.TOOL_FILE: if not self.related_id: raise ValueError("Missing file related_id") + case FileTransferMethod.DATASOURCE_FILE: + if not self.related_id: + raise ValueError("Missing file related_id") return self @property diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 400b00ef83..8c86482cfc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -288,7 +288,8 @@ class DatasetService: names, "Untitled", ) - + if not current_user or 
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 400b00ef83..8c86482cfc 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -288,7 +288,8 @@ class DatasetService:
             names,
             "Untitled",
         )
-
+        if not current_user or not current_user.id:
+            raise ValueError("Current user or current user id not found")
         pipeline = Pipeline(
             tenant_id=tenant_id,
             name=rag_pipeline_dataset_create_entity.name,
@@ -814,6 +815,8 @@ class DatasetService:
     def update_rag_pipeline_dataset_settings(
         session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
     ):
+        if not current_user or not current_user.current_tenant_id:
+            raise ValueError("Current user or current tenant not found")
         dataset = session.merge(dataset)
         if not has_published:
             dataset.chunk_structure = knowledge_configuration.chunk_structure
@@ -821,7 +824,7 @@ class DatasetService:
             if knowledge_configuration.indexing_technique == "high_quality":
                 model_manager = ModelManager()
                 embedding_model = model_manager.get_model_instance(
-                    tenant_id=current_user.current_tenant_id,
+                    tenant_id=current_user.current_tenant_id,  # type: ignore
                     provider=knowledge_configuration.embedding_model_provider or "",
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=knowledge_configuration.embedding_model or "",
@@ -895,6 +898,7 @@ class DatasetService:
         ):
             action = "update"
             model_manager = ModelManager()
+            embedding_model = None
             try:
                 embedding_model = model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
@@ -908,14 +912,15 @@ class DatasetService:
                 # Skip the rest of the embedding model update
                 skip_embedding_update = True
             if not skip_embedding_update:
-                dataset.embedding_model = embedding_model.model
-                dataset.embedding_model_provider = embedding_model.provider
-                dataset_collection_binding = (
-                    DatasetCollectionBindingService.get_dataset_collection_binding(
-                        embedding_model.provider, embedding_model.model
+                if embedding_model:
+                    dataset.embedding_model = embedding_model.model
+                    dataset.embedding_model_provider = embedding_model.provider
+                    dataset_collection_binding = (
+                        DatasetCollectionBindingService.get_dataset_collection_binding(
+                            embedding_model.provider, embedding_model.model
+                        )
                     )
-                )
-                dataset.collection_binding_id = dataset_collection_binding.id
+                    dataset.collection_binding_id = dataset_collection_binding.id
         except LLMBadRequestError:
             raise ValueError(
                 "No Embedding Model available. Please configure a valid provider "
@@ -1014,6 +1019,8 @@ class DatasetService:
         if dataset is None:
             raise NotFound("Dataset not found.")
         dataset.enable_api = status
+        if not current_user or not current_user.id:
+            raise ValueError("Current user or current user id not found")
         dataset.updated_by = current_user.id
         dataset.updated_at = naive_utc_now()
         db.session.commit()
@@ -1350,6 +1357,8 @@ class DocumentService:
             redis_client.setex(retry_indexing_cache_key, 600, 1)
             # trigger async task
             document_ids = [document.id for document in documents]
+            if not current_user or not current_user.id:
+                raise ValueError("Current user or current user id not found")
            retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)

     @staticmethod
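The `embedding_model = None` pre-binding above matters because a name assigned only inside `try` is unbound on the exception path. A self-contained sketch of the pattern; the error type is a hypothetical stand-in for `ProviderTokenNotInitError`:

```python
class ProviderNotConfigured(Exception):
    """Hypothetical stand-in for ProviderTokenNotInitError."""


def get_model_instance(tenant_id: str) -> str:
    raise ProviderNotConfigured("no embedding provider bound")


def update_embedding(tenant_id: str) -> None:
    embedding_model = None  # pre-bind so the name exists on the error path
    skip_embedding_update = False
    try:
        embedding_model = get_model_instance(tenant_id)
    except ProviderNotConfigured:
        skip_embedding_update = True
    if not skip_embedding_update and embedding_model:
        # Dereference only once both checks prove the instance exists,
        # mirroring the added `if embedding_model:` guard.
        print(f"binding collection for {embedding_model}")


update_embedding("tenant-1")  # silently skips: the provider call raised
```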
diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py
index 870360ceb6..8dceeee7ec 100644
--- a/api/services/datasource_provider_service.py
+++ b/api/services/datasource_provider_service.py
@@ -1,5 +1,6 @@
 import logging
 import time
+from collections.abc import Mapping
 from typing import Any

 from flask_login import current_user
@@ -68,11 +69,13 @@ class DatasourceProviderService:
         tenant_id: str,
         provider: str,
         plugin_id: str,
-        raw_credentials: dict[str, Any],
+        raw_credentials: Mapping[str, Any],
         datasource_provider: DatasourceProvider,
     ) -> dict[str, Any]:
         provider_credential_secret_variables = self.extract_secret_variables(
-            tenant_id=tenant_id, provider_id=f"{plugin_id}/{provider}", credential_type=datasource_provider.auth_type
+            tenant_id=tenant_id,
+            provider_id=f"{plugin_id}/{provider}",
+            credential_type=CredentialType.of(datasource_provider.auth_type),
         )
         encrypted_credentials = raw_credentials.copy()
         for key, value in encrypted_credentials.items():
diff --git a/api/services/message_service.py b/api/services/message_service.py
index e2e27443ba..5df80b7aa3 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -241,6 +241,9 @@ class MessageService:

         app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow)

+        if not app_config.additional_features:
+            raise ValueError("Additional features not found")
+
         if not app_config.additional_features.suggested_questions_after_answer:
             raise SuggestedQuestionsAfterAnswerDisabledError()

diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index 99c192d709..bbe102f5ee 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -828,10 +828,10 @@ class RagPipelineService:
             )
             error = node_run_result.error if not run_succeeded else None
         except WorkflowNodeRunFailedError as e:
-            node_instance = e._node
+            node_instance = e._node  # type: ignore
             run_succeeded = False
             node_run_result = None
-            error = e._error
+            error = e._error  # type: ignore

         workflow_node_execution = WorkflowNodeExecution(
             id=str(uuid4()),
@@ -1253,7 +1253,7 @@ class RagPipelineService:
         repository.save(workflow_node_execution)

         # Convert node_execution to WorkflowNodeExecution after save
-        workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)
+        workflow_node_execution_db_model = repository._to_db_model(workflow_node_execution)  # type: ignore

         with Session(bind=db.engine) as session, session.begin():
             draft_var_saver = DraftVariableSaver(
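Loosening `raw_credentials` from `dict` to `Mapping` (as above) lets callers pass any read-only mapping, at the cost that the function should materialize its own mutable copy before editing. A sketch of that contract, with a reversible placeholder in place of real encryption:

```python
from collections.abc import Mapping
from typing import Any


def encrypt_credentials(raw_credentials: Mapping[str, Any]) -> dict[str, Any]:
    # Mapping signals "read-only input"; dict() makes a shallow copy we own.
    encrypted = dict(raw_credentials)
    for key, value in encrypted.items():
        if isinstance(value, str):
            encrypted[key] = value[::-1]  # placeholder, not real encryption
    return encrypted


print(encrypt_credentials({"api_key": "secret", "timeout": 30}))
```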
diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py
index 6ddb7a70ae..78440b4889 100644
--- a/api/services/rag_pipeline/rag_pipeline_transform_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py
@@ -47,6 +47,8 @@ class RagPipelineTransformService:
         self._deal_dependencies(pipeline_yaml, dataset.tenant_id)
         # Extract app data
         workflow_data = pipeline_yaml.get("workflow")
+        if not workflow_data:
+            raise ValueError("Missing workflow data for rag pipeline")
         graph = workflow_data.get("graph", {})
         nodes = graph.get("nodes", [])
         new_nodes = []
@@ -252,7 +254,7 @@ class RagPipelineTransformService:
             plugin_unique_identifier = dependency.get("value", {}).get("plugin_unique_identifier")
             plugin_id = plugin_unique_identifier.split(":")[0]
             if plugin_id not in installed_plugins_ids:
-                plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)
+                plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(plugin_id)  # type: ignore
                 if plugin_unique_identifier:
                     need_install_plugin_unique_identifiers.append(plugin_unique_identifier)
         if need_install_plugin_unique_identifiers:
diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py
index dd3288d8c8..4362bb0291 100644
--- a/api/services/variable_truncator.py
+++ b/api/services/variable_truncator.py
@@ -1,7 +1,6 @@
 import dataclasses
 from collections.abc import Mapping
-from enum import StrEnum
-from typing import Any, Generic, TypeAlias, TypedDict, TypeVar, overload
+from typing import Any, Generic, TypeAlias, TypeVar, overload

 from configs import dify_config
 from core.file.models import File
@@ -39,30 +38,6 @@ class _PCKeys:
     CHILD_CONTENTS = "child_contents"


-class _QAStructureItem(TypedDict):
-    question: str
-    answer: str
-
-
-class _QAStructure(TypedDict):
-    qa_chunks: list[_QAStructureItem]
-
-
-class _ParentChildChunkItem(TypedDict):
-    parent_content: str
-    child_contents: list[str]
-
-
-class _ParentChildStructure(TypedDict):
-    parent_mode: str
-    parent_child_chunks: list[_ParentChildChunkItem]
-
-
-class _SpecialChunkType(StrEnum):
-    parent_child = "parent_child"
-    qa = "qa"
-
-
 _T = TypeVar("_T")


@@ -392,7 +367,7 @@ class VariableTruncator:
     def _truncate_json_primitives(self, val: dict, target_size: int) -> _PartResult[dict]: ...

     @overload
-    def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...
+    def _truncate_json_primitives(self, val: bool, target_size: int) -> _PartResult[bool]: ...  # type: ignore

     @overload
     def _truncate_json_primitives(self, val: int, target_size: int) -> _PartResult[int]: ...
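The `# type: ignore` on the `bool` overload above exists because `bool` is a subclass of `int`: the `bool` overload must be declared before the `int` one to ever match, and checkers flag that ordering as an overlapping overload. Reduced to its essentials:

```python
from typing import overload


@overload
def passthrough(val: bool) -> bool: ...  # type: ignore
@overload
def passthrough(val: int) -> int: ...
def passthrough(val: bool | int) -> bool | int:
    # bool must come first: bool is a subtype of int, so the int overload
    # would otherwise capture True/False and erase the more precise type.
    return val


print(passthrough(True), passthrough(42))
```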
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 9ce5b6dbe0..dccd891981 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -146,7 +146,7 @@ class WorkflowConverter:
             graph=graph,
             model_config=app_config.model,
             prompt_template=app_config.prompt_template,
-            file_upload=app_config.additional_features.file_upload,
+            file_upload=app_config.additional_features.file_upload if app_config.additional_features else None,
             external_data_variable_node_mapping=external_data_variable_node_mapping,
         )

diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index e79a0802b6..1378c20128 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -430,6 +430,10 @@ class WorkflowDraftVariableService:
             .where(WorkflowDraftVariable.id == variable.id)
         )
         variable_reloaded = self._session.execute(variable_query).scalars().first()
+        if variable_reloaded is None:
+            logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable.id)
+            self._session.delete(variable)
+            return
         variable_file = variable_reloaded.variable_file
         if variable_file is None:
             logger.warning(
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 9b5a2ea274..520b0b8fc0 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -811,7 +811,7 @@ class WorkflowService:

             return node, node_run_result, run_succeeded, error
         except WorkflowNodeRunFailedError as e:
-            return e._node, None, False, e._error
+            return e._node, None, False, e._error  # type: ignore

     def _apply_error_strategy(self, node: Node, node_run_result: NodeRunResult) -> NodeRunResult:
         """Apply error strategy when node execution fails."""
diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
index 8634530418..4780e48558 100644
--- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
@@ -156,7 +156,8 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
             from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

             pipeline_generator = PipelineGenerator()
-            pipeline_generator._generate(
+            # Using protected method intentionally for async execution
+            pipeline_generator._generate(  # type: ignore[attr-defined]
                 flask_app=flask_app,
                 context=context,
                 pipeline=pipeline,
diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
index cac37a565a..72916972df 100644
--- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
@@ -177,7 +177,8 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
             from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

             pipeline_generator = PipelineGenerator()
-            pipeline_generator._generate(
+            # Using protected method intentionally for async execution
+            pipeline_generator._generate(  # type: ignore[attr-defined]
                 flask_app=flask_app,
                 context=context,
                 pipeline=pipeline,
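The `workflow_draft_variable_service.py` hunk above handles a reload race: the row backing `variable` can vanish between queries, so the service logs a warning, drops the stale handle, and returns instead of dereferencing `None`. A sketch of the shape of that guard, with a plain dict standing in for the SQLAlchemy session:

```python
import logging

logger = logging.getLogger(__name__)


def offload_variable(session: dict[str, dict], variable_id: str) -> None:
    # Re-read the row before touching its relationships; a concurrent
    # cleanup may already have deleted it.
    reloaded = session.get(variable_id)
    if reloaded is None:
        logger.warning("Associated WorkflowDraftVariable not found, draft_var_id=%s", variable_id)
        return
    print(f"offloading {reloaded['name']}")


offload_variable({"v1": {"name": "answer"}}, "v1")  # offloading answer
offload_variable({"v1": {"name": "answer"}}, "v2")  # warning, no crash
```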
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 2eb3f2a112..f8f39583ac 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -452,12 +452,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:

     # Delete from object storage and collect upload file IDs
     upload_file_ids = []
-    for variable_file_id, storage_key, upload_file_id in file_records:
+    for _, storage_key, upload_file_id in file_records:
         try:
             storage.delete(storage_key)
             upload_file_ids.append(upload_file_id)
             files_deleted += 1
-        except Exception as e:
+        except Exception:
             logging.exception("Failed to delete storage object %s", storage_key)
             # Continue with database cleanup even if storage deletion fails
             upload_file_ids.append(upload_file_id)
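Two small fixes in the last hunk are worth noting: the unused `variable_file_id` becomes `_`, and `except Exception as e` drops the unused binding, since `logging.exception` already attaches the active traceback. The loop's contract, in isolation (the dict-backed storage is a stand-in for the real storage client):

```python
import logging


def delete_storage_objects(file_records: list[tuple[str, str, str]], storage: dict[str, bytes]) -> list[str]:
    upload_file_ids: list[str] = []
    for _, storage_key, upload_file_id in file_records:  # first field unused
        try:
            del storage[storage_key]
        except Exception:
            # logging.exception logs the traceback of the in-flight error,
            # so binding it with `as e` adds nothing.
            logging.exception("Failed to delete storage object %s", storage_key)
        # Database cleanup proceeds whether or not storage deletion worked.
        upload_file_ids.append(upload_file_id)
    return upload_file_ids


print(delete_storage_objects([("vf1", "k1", "u1"), ("vf2", "gone", "u2")], {"k1": b"data"}))
```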