diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index 11df511840..e71b774d3e 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -90,7 +90,7 @@ class ModelConfigResource(Resource): if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity.model_validate(tool) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -124,7 +124,7 @@ class ModelConfigResource(Resource): # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict for tool in agent_mode.get("tools") or []: - agent_tool_entity = AgentToolEntity(**tool) + agent_tool_entity = AgentToolEntity.model_validate(tool) # get tool key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 370e0c0d14..b0f18c11d4 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType, from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin from core.indexing_runner import IndexingRunner from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields @@ -257,13 +257,15 @@ class DataSourceNotionApi(Resource): for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, - }, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index ac088b790e..284f88ff1e 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -24,7 +24,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from fields.app_fields import related_app_list @@ -513,13 +513,15 @@ class DatasetIndexingEstimateApi(Resource): for page in notion_info["pages"]: extract_setting = ExtractSetting( 
datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": credential_id, - "notion_workspace_id": workspace_id, - "notion_obj_id": page["page_id"], - "notion_page_type": page["type"], - "tenant_id": current_user.current_tenant_id, - }, + notion_info=NotionInfo.model_validate( + { + "credential_id": credential_id, + "notion_workspace_id": workspace_id, + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) @@ -528,14 +530,16 @@ class DatasetIndexingEstimateApi(Resource): for url in website_info_list["urls"]: extract_setting = ExtractSetting( datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": website_info_list["provider"], - "job_id": website_info_list["job_id"], - "url": url, - "tenant_id": current_user.current_tenant_id, - "mode": "crawl", - "only_main_content": website_info_list["only_main_content"], - }, + website_info=WebsiteInfo.model_validate( + { + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], + "url": url, + "tenant_id": current_user.current_tenant_id, + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], + } + ), document_model=args["doc_form"], ) extract_settings.append(extract_setting) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index c5fa2061bf..a90730e997 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from extensions.ext_database import db from fields.document_fields import ( dataset_and_document_fields, @@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource): "doc_language", type=str, default="English", required=False, nullable=False, location="json" ) args = parser.parse_args() - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) if not dataset.indexing_technique and not knowledge_config.indexing_technique: raise ValueError("indexing_technique is required.") @@ -395,7 +395,7 @@ class DatasetInitApi(Resource): parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json") args = parser.parse_args() - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) if knowledge_config.indexing_technique == "high_quality": if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") @@ -547,13 +547,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): continue extract_setting = ExtractSetting( datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": data_source_info["credential_id"], - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": 
data_source_info["type"], - "tenant_id": current_user.current_tenant_id, - }, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info["credential_id"], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, + } + ), document_model=document.doc_form, ) extract_settings.append(extract_setting) @@ -562,14 +564,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): continue extract_setting = ExtractSetting( datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "url": data_source_info["url"], - "tenant_id": current_user.current_tenant_id, - "mode": data_source_info["mode"], - "only_main_content": data_source_info["only_main_content"], - }, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], + "tenant_id": current_user.current_tenant_id, + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), document_model=document.doc_form, ) extract_settings.append(extract_setting) diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 9f2805e2c6..d6bd02483d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): ) args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) - segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset) + segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset) return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource): args = parser.parse_args() try: chunks_data = args["chunks"] - chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data] + chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data] child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset) except ChildChunkIndexingServiceError as e: raise ChildChunkIndexingError(str(e)) diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index dc3cd3fce9..8438458617 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource): parser.add_argument("type", type=str, required=True, nullable=False, location="json") parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource): parser = reqparse.RequestParser() parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json") args = parser.parse_args() - metadata_args = MetadataOperationData(**args) + metadata_args = MetadataOperationData.model_validate(args) 
MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 3af590afc8..e021f95283 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource): nullable=True, ) args = parser.parse_args() - pipeline_template_info = PipelineTemplateInfoEntity(**args) + pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args) RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info) return 200 diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index b683aa3160..a36d6b0745 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo raise ValueError("invalid json") try: - payload = payload_type(**data) + payload = payload_type.model_validate(data) except Exception as e: raise ValueError(f"invalid payload: {str(e)}") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 961b96db91..92bbb76f0f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -280,7 +280,7 @@ class DatasetListApi(DatasetApiResource): external_knowledge_id=args["external_knowledge_id"], embedding_model_provider=args["embedding_model_provider"], embedding_model_name=args["embedding_model"], - retrieval_model=RetrievalModel(**args["retrieval_model"]) + retrieval_model=RetrievalModel.model_validate(args["retrieval_model"]) if args["retrieval_model"] is not None else None, ) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index c1122acd7b..961a338bc5 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource): "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } args["data_source"] = data_source - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) # validate args DocumentService.document_create_args_validate(knowledge_config) @@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): args["data_source"] = data_source # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) try: @@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource): } args["data_source"] = data_source # validate args - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) DocumentService.document_create_args_validate(knowledge_config) dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None @@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): # validate args args["original_document_id"] = str(document_id) - knowledge_config = KnowledgeConfig(**args) + knowledge_config = KnowledgeConfig.model_validate(args) 
DocumentService.document_create_args_validate(knowledge_config) try: diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index e01659dc68..51420fdd5f 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): def post(self, tenant_id, dataset_id): """Create metadata for a dataset.""" args = metadata_create_parser.parse_args() - metadata_args = MetadataArgs(**args) + metadata_args = MetadataArgs.model_validate(args) dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) args = document_metadata_parser.parse_args() - metadata_args = MetadataOperationData(**args) + metadata_args = MetadataOperationData.model_validate(args) MetadataService.update_documents_metadata(dataset, metadata_args) diff --git a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py index f05325d711..13ef8abc2d 100644 --- a/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/service_api/dataset/rag_pipeline/rag_pipeline_workflow.py @@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource): parser.add_argument("is_published", type=bool, required=True, location="json") args: ParseResult = parser.parse_args() - datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args) + datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args) assert isinstance(current_user, Account) rag_pipeline_service: RagPipelineService = RagPipelineService() pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index a22155b07a..d674c7467d 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource): args = segment_update_parser.parse_args() updated_segment = SegmentService.update_segment( - SegmentUpdateArgs(**args["segment"]), segment, document, dataset + SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index eab26e5af9..c1f336fdde 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -40,7 +40,7 @@ class AgentConfigManager: "credential_id": tool.get("credential_id", None), } - agent_tools.append(AgentToolEntity(**agent_tool_properties)) + agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in { "react_router", diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 145f629c4d..866c46d963 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -116,7 +116,7 @@ class 
PipelineRunner(WorkflowBasedAppRunner): rag_pipeline_variables = [] if workflow.rag_pipeline_variables: for v in workflow.rag_pipeline_variables: - rag_pipeline_variable = RAGPipelineVariable(**v) + rag_pipeline_variable = RAGPipelineVariable.model_validate(v) if ( rag_pipeline_variable.belong_to_node_id in (self.application_generate_entity.start_node_id, "shared") diff --git a/api/core/datasource/entities/common_entities.py b/api/core/datasource/entities/common_entities.py index ac36d83ae3..3c64632dbb 100644 --- a/api/core/datasource/entities/common_entities.py +++ b/api/core/datasource/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class I18nObject(BaseModel): @@ -11,11 +11,12 @@ class I18nObject(BaseModel): pt_BR: str | None = Field(default=None) ja_JP: str | None = Field(default=None) - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): self.zh_Hans = self.zh_Hans or self.en_US self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US + return self def to_dict(self) -> dict: return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 111de89178..2857729a81 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -5,7 +5,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) - def __init__(self, **data): - super().__init__(**data) - + @model_validator(mode="after") + def _(self): if self.provider.provider not in original_provider_configurate_methods: original_provider_configurate_methods[self.provider.provider] = [] for configurate_method in self.provider.configurate_methods: @@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel): and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods ): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) + return self def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 0c1d03dc13..f92278f9e2 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -131,7 +131,7 @@ class CodeExecutor: if (code := response_data.get("code")) != 0: raise CodeExecutionError(f"Got error code: {code}. 
Got error msg: {response_data.get('message')}") - response_code = CodeExecutionResponse(**response_data) + response_code = CodeExecutionResponse.model_validate(response_data) if response_code.data.error: raise CodeExecutionError(response_code.data.error) diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index 10f304c087..bddb864a95 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version}) response.raise_for_status() - return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] + return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]] def batch_fetch_plugin_manifests_ignore_deserialization_error( @@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error( result: list[MarketplacePluginDeclaration] = [] for plugin in response.json()["data"]["plugins"]: try: - result.append(MarketplacePluginDeclaration(**plugin)) + result.append(MarketplacePluginDeclaration.model_validate(plugin)) except Exception: pass diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index ee37024260..3682fdb667 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.datasource_type import DatasourceType -from core.rag.extractor.entity.extract_setting import ExtractSetting +from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -357,14 +357,16 @@ class IndexingRunner: raise ValueError("no notion import info found") extract_setting = ExtractSetting( datasource_type=DatasourceType.NOTION.value, - notion_info={ - "credential_id": data_source_info["credential_id"], - "notion_workspace_id": data_source_info["notion_workspace_id"], - "notion_obj_id": data_source_info["notion_page_id"], - "notion_page_type": data_source_info["type"], - "document": dataset_document, - "tenant_id": dataset_document.tenant_id, - }, + notion_info=NotionInfo.model_validate( + { + "credential_id": data_source_info["credential_id"], + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id, + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) @@ -378,14 +380,16 @@ class IndexingRunner: raise ValueError("no website import info found") extract_setting = ExtractSetting( datasource_type=DatasourceType.WEBSITE.value, - website_info={ - "provider": data_source_info["provider"], - "job_id": data_source_info["job_id"], - "tenant_id": dataset_document.tenant_id, - "url": data_source_info["url"], - "mode": data_source_info["mode"], - "only_main_content": 
data_source_info["only_main_content"], - }, + website_info=WebsiteInfo.model_validate( + { + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "tenant_id": dataset_document.tenant_id, + "url": data_source_info["url"], + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], + } + ), document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py index 5817416ba4..fa1d309134 100644 --- a/api/core/mcp/session/client_session.py +++ b/api/core/mcp/session/client_session.py @@ -294,7 +294,7 @@ class ClientSession( method="completion/complete", params=types.CompleteRequestParams( ref=ref, - argument=types.CompletionArgument(**argument), + argument=types.CompletionArgument.model_validate(argument), ), ) ), diff --git a/api/core/model_runtime/entities/common_entities.py b/api/core/model_runtime/entities/common_entities.py index c7353de5af..b673efae22 100644 --- a/api/core/model_runtime/entities/common_entities.py +++ b/api/core/model_runtime/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from pydantic import BaseModel, model_validator class I18nObject(BaseModel): @@ -9,7 +9,8 @@ class I18nObject(BaseModel): zh_Hans: str | None = None en_US: str - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): if not self.zh_Hans: self.zh_Hans = self.en_US + return self diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index 2ccc9e0eae..831fb9d4db 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/core/model_runtime/entities/provider_entities.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from enum import Enum, StrEnum, auto -from pydantic import BaseModel, ConfigDict, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, ModelType @@ -46,10 +46,11 @@ class FormOption(BaseModel): value: str show_on: list[FormShowOnObject] = [] - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _(self): if not self.label: self.label = I18nObject(en_US=self.value) + return self class CredentialFormSchema(BaseModel): diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index e070c17abd..e1afc41bee 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -269,17 +269,17 @@ class ModelProviderFactory: } if model_type == ModelType.LLM: - return LargeLanguageModel(**init_params) # type: ignore + return LargeLanguageModel.model_validate(init_params) elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel(**init_params) # type: ignore + return TextEmbeddingModel.model_validate(init_params) elif model_type == ModelType.RERANK: - return RerankModel(**init_params) # type: ignore + return RerankModel.model_validate(init_params) elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel(**init_params) # type: ignore + return Speech2TextModel.model_validate(init_params) elif 
model_type == ModelType.MODERATION: - return ModerationModel(**init_params) # type: ignore + return ModerationModel.model_validate(init_params) elif model_type == ModelType.TTS: - return TTSModel(**init_params) # type: ignore + return TTSModel.model_validate(init_params) def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 573f4ec2a7..2d72b17a04 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -51,7 +51,7 @@ class ApiModeration(Moderation): params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump()) - return ModerationInputsResult(**result) + return ModerationInputsResult.model_validate(result) return ModerationInputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response @@ -67,7 +67,7 @@ class ApiModeration(Moderation): params = ModerationOutputParams(app_id=self.app_id, text=text) result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump()) - return ModerationOutputsResult(**result) + return ModerationOutputsResult.model_validate(result) return ModerationOutputsResult( flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 10f37f75f8..7b789d8ac9 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel): for i in range(len(v)): if v[i]["role"] == PromptMessageRole.USER.value: - v[i] = UserPromptMessage(**v[i]) + v[i] = UserPromptMessage.model_validate(v[i]) elif v[i]["role"] == PromptMessageRole.ASSISTANT.value: - v[i] = AssistantPromptMessage(**v[i]) + v[i] = AssistantPromptMessage.model_validate(v[i]) elif v[i]["role"] == PromptMessageRole.SYSTEM.value: - v[i] = SystemPromptMessage(**v[i]) + v[i] = SystemPromptMessage.model_validate(v[i]) elif v[i]["role"] == PromptMessageRole.TOOL.value: - v[i] = ToolPromptMessage(**v[i]) + v[i] = ToolPromptMessage.model_validate(v[i]) else: - v[i] = PromptMessage(**v[i]) + v[i] = PromptMessage.model_validate(v[i]) return v diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 8e3df4da2c..62a5cc535a 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -94,7 +94,7 @@ class BasePluginClient: self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -104,13 +104,13 @@ class BasePluginClient: Make a stream request to the plugin daemon inner API and yield the response as a model. """ for line in self._stream_request(method, path, params, headers, data, files): - yield type(**json.loads(line)) # type: ignore + yield type_(**json.loads(line)) # type: ignore def _request_with_model( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | None = None, params: dict | None = None, @@ -120,13 +120,13 @@ class BasePluginClient: Make a request to the plugin daemon inner API and return the response as a model. 
""" response = self._request(method, path, headers, data, params, files) - return type(**response.json()) # type: ignore + return type_(**response.json()) # type: ignore def _request_with_plugin_daemon_response( self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -140,22 +140,22 @@ class BasePluginClient: response = self._request(method, path, headers, data, params, files) response.raise_for_status() except HTTPError as e: - msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}" - logger.exception(msg) + logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) raise e except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" - logger.exception(msg) + logger.exception("Failed to request plugin daemon, url: %s", path) raise ValueError(msg) from e try: json_response = response.json() if transformer: json_response = transformer(json_response) - rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore + # https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why + rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore except Exception: msg = ( - f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}]," + f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}]," f" url: {path}" ) logger.exception(msg) @@ -163,7 +163,7 @@ class BasePluginClient: if rep.code != 0: try: - error = PluginDaemonError(**json.loads(rep.message)) + error = PluginDaemonError.model_validate(json.loads(rep.message)) except Exception: raise ValueError(f"{rep.message}, code: {rep.code}") @@ -178,7 +178,7 @@ class BasePluginClient: self, method: str, path: str, - type: type[T], + type_: type[T], headers: dict | None = None, data: bytes | dict | None = None, params: dict | None = None, @@ -189,7 +189,7 @@ class BasePluginClient: """ for line in self._stream_request(method, path, params, headers, data, files): try: - rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore + rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore except (ValueError, TypeError): # TODO modify this when line_data has code and message try: @@ -204,7 +204,7 @@ class BasePluginClient: if rep.code != 0: if rep.code == -500: try: - error = PluginDaemonError(**json.loads(rep.message)) + error = PluginDaemonError.model_validate(json.loads(rep.message)) except Exception: raise PluginDaemonInnerError(code=rep.code, message=rep.message) diff --git a/api/core/plugin/impl/datasource.py b/api/core/plugin/impl/datasource.py index 84087f8104..ce1ef71494 100644 --- a/api/core/plugin/impl/datasource.py +++ b/api/core/plugin/impl/datasource.py @@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient): params={"page": 1, "page_size": 256}, transformer=transformer, ) - local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate( + self._get_local_file_datasource_provider() + ) for provider in response: ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider) @@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient): Fetch datasource provider for the given tenant and plugin. 
""" if provider_id == "langgenius/file/file": - return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider()) + return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider()) tool_provider_id = DatasourceProviderID(provider_id) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 153da142f4..5dfc3c212e 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/invoke", - type=LLMResultChunk, + type_=LLMResultChunk, data=jsonable_encoder( { "user_id": user_id, @@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", - type=PluginLLMNumTokensResponse, + type_=PluginLLMNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", - type=TextEmbeddingResult, + type_=TextEmbeddingResult, data=jsonable_encoder( { "user_id": user_id, @@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", - type=PluginTextEmbeddingNumTokensResponse, + type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( { "user_id": user_id, @@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/rerank/invoke", - type=RerankResult, + type_=RerankResult, data=jsonable_encoder( { "user_id": user_id, @@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/tts/invoke", - type=PluginStringResultResponse, + type_=PluginStringResultResponse, data=jsonable_encoder( { "user_id": user_id, @@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/tts/model/voices", - type=PluginVoicesResponse, + type_=PluginVoicesResponse, data=jsonable_encoder( { "user_id": user_id, @@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", - type=PluginStringResultResponse, + type_=PluginStringResultResponse, data=jsonable_encoder( { "user_id": user_id, @@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient): response = self._request_with_plugin_daemon_response_stream( method="POST", path=f"plugin/{tenant_id}/dispatch/moderation/invoke", - type=PluginBasicBooleanResponse, + type_=PluginBasicBooleanResponse, data=jsonable_encoder( { "user_id": user_id, diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 63a1d911ca..38358ccd6d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -134,7 +134,7 @@ class RetrievalService: if not dataset: return [] metadata_condition = ( - 
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None + MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None ) all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( dataset.tenant_id, diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index b9bf9d0d8c..c3bfbce98f 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -17,9 +17,6 @@ class NotionInfo(BaseModel): tenant_id: str model_config = ConfigDict(arbitrary_types_allowed=True) - def __init__(self, **data): - super().__init__(**data) - class WebsiteInfo(BaseModel): """ @@ -47,6 +44,3 @@ class ExtractSetting(BaseModel): website_info: WebsiteInfo | None = None document_model: str | None = None model_config = ConfigDict(arbitrary_types_allowed=True) - - def __init__(self, **data): - super().__init__(**data) diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 755aa88d08..4fcffbcc77 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if process_rule.get("mode") == "automatic": automatic_rule = DatasetProcessRule.AUTOMATIC_RULES - rules = Rule(**automatic_rule) + rules = Rule.model_validate(automatic_rule) else: if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) # Split the text documents into nodes. if not rules.segmentation: raise ValueError("No segmentation found in rules.") diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index e0ccd8b567..7bdde286f5 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) all_documents: list[Document] = [] if rules.parent_mode == ParentMode.PARAGRAPH: # Split the text documents into nodes. 
@@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): child_documents = document.children if child_documents: formatted_child_documents = [ - Document(**child_document.model_dump()) for child_document in child_documents + Document.model_validate(child_document.model_dump()) for child_document in child_documents ] vector.create(formatted_child_documents) @@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): return child_nodes def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): - parent_childs = ParentChildStructureChunk(**chunks) + parent_childs = ParentChildStructureChunk.model_validate(chunks) documents = [] for parent_child in parent_childs.parent_child_chunks: metadata = { @@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): vector.create(all_child_documents) def format_preview(self, chunks: Any) -> Mapping[str, Any]: - parent_childs = ParentChildStructureChunk(**chunks) + parent_childs = ParentChildStructureChunk.model_validate(chunks) preview = [] for parent_child in parent_childs.parent_child_chunks: preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents}) diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 2054031643..9c8f70dba8 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError("No process rule found.") if not process_rule.get("rules"): raise ValueError("No rules found in process rule.") - rules = Rule(**process_rule.get("rules")) + rules = Rule.model_validate(process_rule.get("rules")) splitter = self._get_splitter( processing_rule_mode=process_rule.get("mode"), max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0, @@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor): return docs def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any): - qa_chunks = QAStructureChunk(**chunks) + qa_chunks = QAStructureChunk.model_validate(chunks) documents = [] for qa_chunk in qa_chunks.qa_chunks: metadata = { @@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor): raise ValueError("Indexing technique must be high quality.") def format_preview(self, chunks: Any) -> Mapping[str, Any]: - qa_chunks = QAStructureChunk(**chunks) + qa_chunks = QAStructureChunk.model_validate(chunks) preview = [] for qa_chunk in qa_chunks.qa_chunks: preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer}) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 45fd16d684..29d34e722a 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController): tools.append( assistant_tool_class( provider=provider, - entity=ToolEntity(**tool), + entity=ToolEntity.model_validate(tool), runtime=ToolRuntime(tenant_id=""), ) ) diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 2c6d9c1964..21d310bbb9 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator class I18nObject(BaseModel): @@ -11,11 +11,12 @@ class 
I18nObject(BaseModel): pt_BR: str | None = Field(default=None) ja_JP: str | None = Field(default=None) - def __init__(self, **data): - super().__init__(**data) + @model_validator(mode="after") + def _populate_missing_locales(self): self.zh_Hans = self.zh_Hans or self.en_US self.pt_BR = self.pt_BR or self.en_US self.ja_JP = self.ja_JP or self.en_US + return self def to_dict(self): return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP} diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py index 5b04f0edbe..f269b8db9b 100644 --- a/api/core/tools/mcp_tool/provider.py +++ b/api/core/tools/mcp_tool/provider.py @@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController): """ tools = [] tools_data = json.loads(db_provider.tools) - remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data] + remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data] user = db_provider.load_user() tools = [ ToolEntity( diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 9e5f5a7c23..af68971ca7 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -1008,7 +1008,7 @@ class ToolManager: config = tool_configurations.get(parameter.name, {}) if not (config and isinstance(config, dict) and config.get("value") is not None): continue - tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {})) + tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {})) if tool_input.type == "variable": variable = variable_pool.get(tool_input.value) if variable is None: diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 056e17bf5d..c841459170 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -105,10 +105,10 @@ class RedisChannel: command_type = CommandType(command_type_value) if command_type == CommandType.ABORT: - return AbortCommand(**data) + return AbortCommand.model_validate(data) else: # For other command types, use base class - return GraphEngineCommand(**data) + return GraphEngineCommand.model_validate(data) except (ValueError, TypeError): return None diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 2bdfe4efce..7ec74084d0 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -16,7 +16,7 @@ class EndNode(Node): _node_data: EndNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = EndNodeData(**data) + self._node_data = EndNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 80f39ccebc..90b7f4539b 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -18,7 +18,7 @@ class IterationStartNode(Node): _node_data: IterationStartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = IterationStartNodeData(**data) + self._node_data = IterationStartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git 
a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3243b22d44..180eb2ad90 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -41,7 +41,7 @@ class ListOperatorNode(Node): _node_data: ListOperatorNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = ListOperatorNodeData(**data) + self._node_data = ListOperatorNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 38aef06d24..e5bce1230c 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -18,7 +18,7 @@ class LoopEndNode(Node): _node_data: LoopEndNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopEndNodeData(**data) + self._node_data = LoopEndNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index e777a8cbe9..e065dc90a0 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -18,7 +18,7 @@ class LoopStartNode(Node): _node_data: LoopStartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = LoopStartNodeData(**data) + self._node_data = LoopStartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 2f33c54128..3b134be1a1 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -16,7 +16,7 @@ class StartNode(Node): _node_data: StartNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = StartNodeData(**data) + self._node_data = StartNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index be00d55937..0ac0d3d858 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -15,7 +15,7 @@ class VariableAggregatorNode(Node): _node_data: VariableAssignerNodeData def init_node_data(self, data: Mapping[str, Any]): - self._node_data = VariableAssignerNodeData(**data) + self._node_data = VariableAssignerNodeData.model_validate(data) def _get_error_strategy(self) -> ErrorStrategy | None: return self._node_data.error_strategy diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 6c9fc0bf1d..21b73b76b5 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -14,7 +14,7 @@ def handle(sender, **kwargs): for node_data in synced_draft_workflow.graph_dict.get("nodes", []): if node_data.get("data", {}).get("type") == NodeType.TOOL.value: try: - 
tool_entity = ToolEntity(**node_data["data"]) + tool_entity = ToolEntity.model_validate(node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( provider_type=tool_entity.provider_type, provider_id=tool_entity.provider_id, diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 898ec1f153..7605d4082c 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: for node in knowledge_retrieval_nodes: try: - node_data = KnowledgeRetrievalNodeData(**node.get("data", {})) + node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {})) dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids) except Exception: continue diff --git a/api/models/dataset.py b/api/models/dataset.py index 25ebe14738..6263c04365 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -754,7 +754,7 @@ class DocumentSegment(Base): if process_rule and process_rule.mode == "hierarchical": rules_dict = process_rule.rules_dict if rules_dict: - rules = Rule(**rules_dict) + rules = Rule.model_validate(rules_dict) if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: child_chunks = ( db.session.query(ChildChunk) @@ -772,7 +772,7 @@ class DocumentSegment(Base): if process_rule and process_rule.mode == "hierarchical": rules_dict = process_rule.rules_dict if rules_dict: - rules = Rule(**rules_dict) + rules = Rule.model_validate(rules_dict) if rules.parent_mode: child_chunks = ( db.session.query(ChildChunk) diff --git a/api/models/tools.py b/api/models/tools.py index 7211d7aa3a..d581d588a4 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -152,7 +152,7 @@ class ApiToolProvider(Base): def tools(self) -> list["ApiToolBundle"]: from core.tools.entities.tool_bundle import ApiToolBundle - return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)] + return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)] @property def credentials(self) -> dict[str, Any]: @@ -242,7 +242,10 @@ class WorkflowToolProvider(Base): def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]: from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration - return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)] + return [ + WorkflowToolParameterConfiguration.model_validate(config) + for config in json.loads(self.parameter_configuration) + ] @property def app(self) -> App | None: @@ -312,7 +315,7 @@ class MCPToolProvider(Base): def mcp_tools(self) -> list["MCPTool"]: from core.mcp.types import Tool as MCPTool - return [MCPTool(**tool) for tool in json.loads(self.tools)] + return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)] @property def provider_icon(self) -> Mapping[str, str] | str: @@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base): def description_i18n(self) -> "I18nObject": from core.tools.entities.common_entities import I18nObject - return I18nObject(**json.loads(self.description)) + return I18nObject.model_validate(json.loads(self.description)) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 8701fe4f4e..129e3b0492 100644 --- 
a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -659,31 +659,31 @@ class AppDslService: typ = node.get("data", {}).get("type") match typ: case NodeType.TOOL.value: - tool_entity = ToolNodeData(**node["data"]) + tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) case NodeType.LLM.value: - llm_entity = LLMNodeData(**node["data"]) + llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) case NodeType.QUESTION_CLASSIFIER.value: - question_classifier_entity = QuestionClassifierNodeData(**node["data"]) + question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) case NodeType.PARAMETER_EXTRACTOR.value: - parameter_extractor_entity = ParameterExtractorNodeData(**node["data"]) + parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) case NodeType.KNOWLEDGE_RETRIEVAL.value: - knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"]) + knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: if ( @@ -773,7 +773,7 @@ class AppDslService: """ Returns the leaked dependencies in current workspace """ - dependencies = [PluginDependency(**dep) for dep in dsl_dependencies] + dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies] if not dependencies: return [] diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index f8612456d6..4fbf33fd6f 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -70,7 +70,7 @@ class EnterpriseService: data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) if not data: raise ValueError("No data found.") - return WebAppSettings(**data) + return WebAppSettings.model_validate(data) @classmethod def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: @@ -100,7 +100,7 @@ class EnterpriseService: data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) if not data: raise ValueError("No data found.") - return WebAppSettings(**data) + return WebAppSettings.model_validate(data) @classmethod def update_app_access_mode(cls, app_id: str, access_mode: str): diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 49d48f044c..0f5151919f 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,7 @@ +from collections.abc import Sequence from enum import Enum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config from core.entities.model_entities import ( @@ -71,7 +72,7 @@ class ProviderResponse(BaseModel): icon_large: I18nObject | None = None background: str | None = None help: ProviderHelpEntity | None = None - supported_model_types: 
diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py
index 49d48f044c..0f5151919f 100644
--- a/api/services/entities/model_provider_entities.py
+++ b/api/services/entities/model_provider_entities.py
@@ -1,6 +1,7 @@
+from collections.abc import Sequence
 from enum import Enum
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, model_validator
 
 from configs import dify_config
 from core.entities.model_entities import (
@@ -71,7 +72,7 @@ class ProviderResponse(BaseModel):
     icon_large: I18nObject | None = None
     background: str | None = None
     help: ProviderHelpEntity | None = None
-    supported_model_types: list[ModelType]
+    supported_model_types: Sequence[ModelType]
     configurate_methods: list[ConfigurateMethod]
     provider_credential_schema: ProviderCredentialSchema | None = None
     model_credential_schema: ModelCredentialSchema | None = None
@@ -82,9 +83,8 @@
     # pydantic configs
    model_config = ConfigDict(protected_namespaces=())
 
-    def __init__(self, **data):
-        super().__init__(**data)
-
+    @model_validator(mode="after")
+    def _(self):
        url_prefix = (
            dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
        )
@@ -97,6 +97,7 @@
         self.icon_large = I18nObject(
             en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
         )
+        return self
 
 
 class ProviderWithModelsResponse(BaseModel):
@@ -112,9 +113,8 @@
     status: CustomConfigurationStatus
     models: list[ProviderModelWithStatusEntity]
 
-    def __init__(self, **data):
-        super().__init__(**data)
-
+    @model_validator(mode="after")
+    def _(self):
         url_prefix = (
             dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
         )
@@ -127,6 +127,7 @@
         self.icon_large = I18nObject(
             en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
         )
+        return self
 
 
 class SimpleProviderEntityResponse(SimpleProviderEntity):
@@ -136,9 +137,8 @@
 
     tenant_id: str
 
-    def __init__(self, **data):
-        super().__init__(**data)
-
+    @model_validator(mode="after")
+    def _(self):
         url_prefix = (
             dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
         )
@@ -151,6 +151,7 @@
         self.icon_large = I18nObject(
             en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
         )
+        return self
 
 
 class DefaultModelResponse(BaseModel):
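
Reviewer note: the model_provider_entities.py refactor is the one non-mechanical change here. In Pydantic v2 a custom __init__ only runs on keyword construction; Model.model_validate(...) builds the instance through the core validator and skips it, so the URL post-processing would silently stop running once call sites migrate. An `after` model validator runs on every construction path. A distilled sketch with illustrative field names:

    from pydantic import BaseModel, model_validator

    class ProviderView(BaseModel):  # illustrative stand-in
        tenant_id: str
        provider: str
        icon_small: str | None = None

        @model_validator(mode="after")
        def _fill_icon_urls(self):
            # Runs after field validation, for __init__, model_validate,
            # and model_validate_json alike.
            prefix = f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
            if self.icon_small is None:
                self.icon_small = f"{prefix}/icon_small/en_US"
            return self

    view = ProviderView.model_validate({"tenant_id": "t1", "provider": "openai"})
    assert view.icon_small is not None and view.icon_small.endswith("/icon_small/en_US")
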
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 00ec3babf3..6174ce8b3b 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -46,7 +46,7 @@ class HitTestingService:
 
             from core.app.app_config.entities import MetadataFilteringCondition
 
-            metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
+            metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
 
             metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
                 dataset_ids=[dataset.id],
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index c214640653..b4b23b8360 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -123,7 +123,7 @@
         config_class: type[BaseTracingConfig] = provider_config["config_class"]
         other_keys: list[str] = provider_config["other_keys"]
 
-        default_config_instance: BaseTracingConfig = config_class(**tracing_config)
+        default_config_instance = config_class.model_validate(tracing_config)
         for key in other_keys:
             if key in tracing_config and tracing_config[key] == "":
                 tracing_config[key] = getattr(default_config_instance, key, None)
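
Reviewer note: in ops_service.py the config class is selected at runtime, so validation dispatches through the class object itself; the explicit `: BaseTracingConfig` annotation could be dropped because model_validate already returns the annotated class. A sketch with made-up registry and class names:

    from pydantic import BaseModel

    class BaseTracingConfigSketch(BaseModel):  # illustrative base
        endpoint: str = ""

    class LangfuseConfigSketch(BaseTracingConfigSketch):
        public_key: str = ""

    # hypothetical registry mirroring provider_config["config_class"]
    REGISTRY: dict[str, type[BaseTracingConfigSketch]] = {"langfuse": LangfuseConfigSketch}

    config_class = REGISTRY["langfuse"]
    # model_validate is a classmethod, so it dispatches to the looked-up
    # subclass while type-checking as the common base.
    default_config = config_class.model_validate({"public_key": "pk-123"})
    assert isinstance(default_config, LangfuseConfigSketch)
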
diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py
index 99946d8fa9..76bb9a57f9 100644
--- a/api/services/plugin/plugin_migration.py
+++ b/api/services/plugin/plugin_migration.py
@@ -269,7 +269,7 @@
                 for tool in agent_config["tools"]:
                     if isinstance(tool, dict):
                         try:
-                            tool_entity = AgentToolEntity(**tool)
+                            tool_entity = AgentToolEntity.model_validate(tool)
                             if (
                                 tool_entity.provider_type == ToolProviderType.BUILT_IN.value
                                 and tool_entity.provider_id not in excluded_providers
diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py
index fdaaa73bcc..3ced0fd9ec 100644
--- a/api/services/rag_pipeline/rag_pipeline.py
+++ b/api/services/rag_pipeline/rag_pipeline.py
@@ -358,7 +358,7 @@
         for node in nodes:
             if node.get("data", {}).get("type") == "knowledge-index":
                 knowledge_configuration = node.get("data", {})
-                knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
+                knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration)
 
         # update dataset
         dataset = pipeline.retrieve_dataset(session=session)
diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
index f74de1bcab..9dede31ab4 100644
--- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py
@@ -288,7 +288,7 @@
         dataset_id = None
         for node in nodes:
             if node.get("data", {}).get("type") == "knowledge-index":
-                knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
+                knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
                 if (
                     dataset
                     and pipeline.is_published
@@ -426,7 +426,7 @@
         dataset_id = None
         for node in nodes:
             if node.get("data", {}).get("type") == "knowledge-index":
-                knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
+                knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
                 if not dataset:
                     dataset = Dataset(
                         tenant_id=account.current_tenant_id,
@@ -734,35 +734,35 @@
             typ = node.get("data", {}).get("type")
             match typ:
                 case NodeType.TOOL.value:
-                    tool_entity = ToolNodeData(**node["data"])
+                    tool_entity = ToolNodeData.model_validate(node["data"])
                     dependencies.append(
                         DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
                     )
                 case NodeType.DATASOURCE.value:
-                    datasource_entity = DatasourceNodeData(**node["data"])
+                    datasource_entity = DatasourceNodeData.model_validate(node["data"])
                     if datasource_entity.provider_type != "local_file":
                         dependencies.append(datasource_entity.plugin_id)
                 case NodeType.LLM.value:
-                    llm_entity = LLMNodeData(**node["data"])
+                    llm_entity = LLMNodeData.model_validate(node["data"])
                     dependencies.append(
                         DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
                     )
                 case NodeType.QUESTION_CLASSIFIER.value:
-                    question_classifier_entity = QuestionClassifierNodeData(**node["data"])
+                    question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
                     dependencies.append(
                         DependenciesAnalysisService.analyze_model_provider_dependency(
                             question_classifier_entity.model.provider
                         ),
                     )
                 case NodeType.PARAMETER_EXTRACTOR.value:
-                    parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
+                    parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
                     dependencies.append(
                         DependenciesAnalysisService.analyze_model_provider_dependency(
                             parameter_extractor_entity.model.provider
                         ),
                     )
                 case NodeType.KNOWLEDGE_INDEX.value:
-                    knowledge_index_entity = KnowledgeConfiguration(**node["data"])
+                    knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
                     if knowledge_index_entity.indexing_technique == "high_quality":
                         if knowledge_index_entity.embedding_model_provider:
                             dependencies.append(
@@ -783,7 +783,7 @@
                             ),
                         )
                 case NodeType.KNOWLEDGE_RETRIEVAL.value:
-                    knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
+                    knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
                     if knowledge_retrieval_entity.retrieval_mode == "multiple":
                         if knowledge_retrieval_entity.multiple_retrieval_config:
                             if (
@@ -873,7 +873,7 @@
         """
         Returns the leaked dependencies in current workspace
        """
-        dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
+        dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
         if not dependencies:
             return []
 
diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py
index 3d5a85b57f..b4425d85a6 100644
--- a/api/services/rag_pipeline/rag_pipeline_transform_service.py
+++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py
@@ -156,13 +156,13 @@
         self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
     ):
         knowledge_configuration_dict = node.get("data", {})
-        knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
+        knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
 
         if indexing_technique == "high_quality":
             knowledge_configuration.embedding_model = dataset.embedding_model
             knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
         if retrieval_model:
-            retrieval_setting = RetrievalSetting(**retrieval_model)
+            retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
             if indexing_technique == "economy":
                 retrieval_setting.search_method = "keyword_search"
             knowledge_configuration.retrieval_model = retrieval_setting
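
Reviewer note: the transform service validates the raw dict once and then mutates fields on the typed object rather than patching the dict. A sketch of that shape (the field set and the validate_assignment flag are assumptions of the sketch, not shown by this diff):

    from pydantic import BaseModel, ConfigDict

    class RetrievalSettingSketch(BaseModel):  # illustrative stand-in
        # Assumed for the sketch: re-validate fields mutated after construction.
        model_config = ConfigDict(validate_assignment=True)

        search_method: str = "semantic_search"
        top_k: int = 4

    setting = RetrievalSettingSketch.model_validate({"search_method": "hybrid_search", "top_k": 2})
    setting.search_method = "keyword_search"  # checked again on assignment
    assert setting.model_dump() == {"search_method": "keyword_search", "top_k": 2}
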
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index 6b36ed0eb7..7ae1b97b30 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -242,7 +242,7 @@ class ToolTransformService:
             is_team_authorization=db_provider.authed,
             server_url=db_provider.masked_server_url,
             tools=ToolTransformService.mcp_tool_to_user_tool(
-                db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
+                db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
             ),
             updated_at=int(db_provider.updated_at.timestamp()),
             label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
@@ -387,6 +387,7 @@ class ToolTransformService:
                 labels=labels or [],
             )
         else:
+            assert tool.operation_id
             return ToolApiEntity(
                 author=tool.author,
                 name=tool.operation_id or "",
diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py
index 7b254ac3b5..72e3b42ca7 100644
--- a/api/tasks/ops_trace_task.py
+++ b/api/tasks/ops_trace_task.py
@@ -36,7 +36,7 @@ def process_trace_tasks(file_info):
         if trace_info.get("workflow_data"):
             trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
         if trace_info.get("documents"):
-            trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]
+            trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
 
         try:
             if trace_instance:
diff --git a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
index a2c99554f1..4171656131 100644
--- a/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py
@@ -79,7 +79,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
     # Create Flask application context for this thread
     with flask_app.app_context():
         try:
-            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
+            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
             user_id = rag_pipeline_invoke_entity_model.user_id
             tenant_id = rag_pipeline_invoke_entity_model.tenant_id
             pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@@ -112,7 +112,7 @@
             workflow_execution_id = str(uuid.uuid4())
 
             # Create application generate entity from dict
-            entity = RagPipelineGenerateEntity(**application_generate_entity)
+            entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
 
             # Create workflow repositories
             session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
index 4e00f072bf..90ebe80daf 100644
--- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py
+++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py
@@ -100,7 +100,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
     # Create Flask application context for this thread
     with flask_app.app_context():
         try:
-            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
+            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
             user_id = rag_pipeline_invoke_entity_model.user_id
             tenant_id = rag_pipeline_invoke_entity_model.tenant_id
             pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@@ -133,7 +133,7 @@
             workflow_execution_id = str(uuid.uuid4())
 
             # Create application generate entity from dict
-            entity = RagPipelineGenerateEntity(**application_generate_entity)
+            entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
 
             # Create workflow repositories
             session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
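
Reviewer note: the two task entry points now validate their queue payloads into typed entities before doing any work, so a malformed message fails fast with a structured ValidationError instead of surfacing halfway through the run. A minimal sketch of that boundary (entity fields are illustrative):

    from collections.abc import Mapping
    from typing import Any

    from pydantic import BaseModel

    class PipelineInvokeEntity(BaseModel):  # illustrative stand-in
        user_id: str
        tenant_id: str
        pipeline_id: str

    def run_task(payload: Mapping[str, Any]) -> str:
        # Validate at the task boundary, before any app context or DB
        # session is set up; downstream code never touches raw keys.
        entity = PipelineInvokeEntity.model_validate(payload)
        return entity.pipeline_id

    assert run_task({"user_id": "u", "tenant_id": "t", "pipeline_id": "p"}) == "p"
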
diff --git a/api/tests/integration_tests/tools/api_tool/test_api_tool.py b/api/tests/integration_tests/tools/api_tool/test_api_tool.py
index 7c1a200c8f..e637530265 100644
--- a/api/tests/integration_tests/tools/api_tool/test_api_tool.py
+++ b/api/tests/integration_tests/tools/api_tool/test_api_tool.py
@@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock):
         entity=ToolEntity(
             identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
         ),
-        api_bundle=ApiToolBundle(**tool_bundle),
+        api_bundle=ApiToolBundle.model_validate(tool_bundle),
         runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
         provider_id="test_tool",
     )
diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py
index 48cc8a7e1c..fb2ddfe162 100644
--- a/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py
+++ b/api/tests/unit_tests/core/rag/datasource/vdb/milvus/test_milvus.py
@@ -11,8 +11,8 @@ def test_default_value():
         config = valid_config.copy()
         del config[key]
         with pytest.raises(ValidationError) as e:
-            MilvusConfig(**config)
+            MilvusConfig.model_validate(config)
         assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
 
-    config = MilvusConfig(**valid_config)
+    config = MilvusConfig.model_validate(valid_config)
     assert config.database == "default"
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index b942614232..55fe62ca43 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -35,7 +35,7 @@ def list_operator_node():
         "extract_by": ExtractConfig(enabled=False, serial="1"),
         "title": "Test Title",
     }
-    node_data = ListOperatorNodeData(**config)
+    node_data = ListOperatorNodeData.model_validate(config)
     node_config = {
         "id": "test_node_id",
         "data": node_data.model_dump(),
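
Reviewer note: for the Milvus test, model_validate raises the same structured ValidationError that keyword construction does, so assertions on errors()[0]["msg"] keep working unchanged. A self-contained sketch of the pattern (ExampleConfig stands in for MilvusConfig's real shape):

    import pytest
    from pydantic import BaseModel, ValidationError, model_validator

    class ExampleConfig(BaseModel):  # illustrative stand-in
        uri: str = ""
        database: str = "default"

        @model_validator(mode="before")
        @classmethod
        def _check_required(cls, values: dict) -> dict:
            if not values.get("uri"):
                raise ValueError("config MILVUS_URI is required")
            return values

    def test_missing_uri():
        with pytest.raises(ValidationError) as e:
            ExampleConfig.model_validate({"database": "default"})
        # Pydantic v2 prefixes ValueError messages with "Value error, ".
        assert e.value.errors()[0]["msg"] == "Value error, config MILVUS_URI is required"
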
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
index f990280c5f..47ef289ef3 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
@@ -17,7 +17,7 @@ def test_init_question_classifier_node_data():
         "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
     }
 
-    node_data = QuestionClassifierNodeData(**data)
+    node_data = QuestionClassifierNodeData.model_validate(data)
 
     assert node_data.query_variable_selector == ["id", "name"]
     assert node_data.model.provider == "openai"
@@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config():
         },
     }
 
-    node_data = QuestionClassifierNodeData(**data)
+    node_data = QuestionClassifierNodeData.model_validate(data)
 
     assert node_data.query_variable_selector == ["id", "name"]
     assert node_data.model.provider == "openai"
diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py
index 11d788ed79..3ae5edb383 100644
--- a/api/tests/unit_tests/core/workflow/test_system_variable.py
+++ b/api/tests/unit_tests/core/workflow/test_system_variable.py
@@ -46,7 +46,7 @@ class TestSystemVariableSerialization:
     def test_basic_deserialization(self):
         """Test successful deserialization from JSON structure with all fields correctly mapped."""
         # Test with complete data
-        system_var = SystemVariable(**COMPLETE_VALID_DATA)
+        system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA)
 
         # Verify all fields are correctly mapped
         assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
@@ -59,7 +59,7 @@
         assert system_var.files == []
 
         # Test with minimal data (only required fields)
-        minimal_var = SystemVariable(**VALID_BASE_DATA)
+        minimal_var = SystemVariable.model_validate(VALID_BASE_DATA)
         assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
         assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
         assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
@@ -75,12 +75,12 @@
 
         # Test workflow_run_id only (preferred alias)
         data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
-        system_var1 = SystemVariable(**data_run_id)
+        system_var1 = SystemVariable.model_validate(data_run_id)
         assert system_var1.workflow_execution_id == workflow_id
 
         # Test workflow_execution_id only (direct field name)
         data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
-        system_var2 = SystemVariable(**data_execution_id)
+        system_var2 = SystemVariable.model_validate(data_execution_id)
         assert system_var2.workflow_execution_id == workflow_id
 
         # Test both present - workflow_run_id should take precedence
@@ -89,17 +89,17 @@
             "workflow_execution_id": "should-be-ignored",
             "workflow_run_id": workflow_id,
         }
-        system_var3 = SystemVariable(**data_both)
+        system_var3 = SystemVariable.model_validate(data_both)
         assert system_var3.workflow_execution_id == workflow_id
 
         # Test neither present - should be None
-        system_var4 = SystemVariable(**VALID_BASE_DATA)
+        system_var4 = SystemVariable.model_validate(VALID_BASE_DATA)
         assert system_var4.workflow_execution_id is None
 
     def test_serialization_round_trip(self):
         """Test that serialize → deserialize produces the same result with alias handling."""
         # Create original SystemVariable
-        original = SystemVariable(**COMPLETE_VALID_DATA)
+        original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
 
         # Serialize to dict
         serialized = original.model_dump(mode="json")
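
Reviewer note: the alias tests above rely on one field accepting two spellings. A simplified stand-in (the real SystemVariable resolves the both-keys-present precedence with a model validator, which this sketch omits):

    from pydantic import BaseModel, ConfigDict, Field

    class RunSketch(BaseModel):  # simplified stand-in for SystemVariable
        model_config = ConfigDict(populate_by_name=True)

        workflow_execution_id: str | None = Field(default=None, alias="workflow_run_id")

    # Either the alias or the field name populates the same attribute.
    assert RunSketch.model_validate({"workflow_run_id": "r1"}).workflow_execution_id == "r1"
    assert RunSketch.model_validate({"workflow_execution_id": "r1"}).workflow_execution_id == "r1"

    # Dumping by alias round-trips under the legacy key.
    assert RunSketch.model_validate({"workflow_run_id": "r1"}).model_dump(by_alias=True) == {"workflow_run_id": "r1"}
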
@@ -110,7 +110,7 @@
         assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
 
         # Deserialize back
-        deserialized = SystemVariable(**serialized)
+        deserialized = SystemVariable.model_validate(serialized)
 
         # Verify all fields match after round-trip
         assert deserialized.user_id == original.user_id
@@ -125,7 +125,7 @@
     def test_json_round_trip(self):
         """Test JSON serialization/deserialization consistency with proper structure."""
         # Create original SystemVariable
-        original = SystemVariable(**COMPLETE_VALID_DATA)
+        original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
 
         # Serialize to JSON string
         json_str = original.model_dump_json()
@@ -137,7 +137,7 @@
         assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
 
         # Deserialize from JSON data
-        deserialized = SystemVariable(**json_data)
+        deserialized = SystemVariable.model_validate(json_data)
 
         # Verify key fields match after JSON round-trip
         assert deserialized.workflow_execution_id == original.workflow_execution_id
@@ -149,13 +149,13 @@
         """Test deserialization with File objects in the files field - SystemVariable specific logic."""
         # Test with empty files list
         data_empty = {**VALID_BASE_DATA, "files": []}
-        system_var_empty = SystemVariable(**data_empty)
+        system_var_empty = SystemVariable.model_validate(data_empty)
         assert system_var_empty.files == []
 
         # Test with single File object
         test_file = create_test_file()
         data_single = {**VALID_BASE_DATA, "files": [test_file]}
-        system_var_single = SystemVariable(**data_single)
+        system_var_single = SystemVariable.model_validate(data_single)
         assert len(system_var_single.files) == 1
         assert system_var_single.files[0].filename == "test.txt"
         assert system_var_single.files[0].tenant_id == "test-tenant-id"
@@ -179,14 +179,14 @@
         )
 
         data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
-        system_var_multiple = SystemVariable(**data_multiple)
+        system_var_multiple = SystemVariable.model_validate(data_multiple)
         assert len(system_var_multiple.files) == 2
         assert system_var_multiple.files[0].filename == "doc1.txt"
         assert system_var_multiple.files[1].filename == "image.jpg"
 
         # Verify files field serialization/deserialization
         serialized = system_var_multiple.model_dump(mode="json")
-        deserialized = SystemVariable(**serialized)
+        deserialized = SystemVariable.model_validate(serialized)
         assert len(deserialized.files) == 2
         assert deserialized.files[0].filename == "doc1.txt"
         assert deserialized.files[1].filename == "image.jpg"
@@ -197,7 +197,7 @@
 
         # Create with workflow_run_id (alias)
         data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
-        system_var = SystemVariable(**data_with_alias)
+        system_var = SystemVariable.model_validate(data_with_alias)
 
         # Serialize and verify alias is used
         serialized = system_var.model_dump()
@@ -205,7 +205,7 @@
         assert "workflow_execution_id" not in serialized
 
         # Deserialize and verify field mapping
-        deserialized = SystemVariable(**serialized)
+        deserialized = SystemVariable.model_validate(serialized)
         assert deserialized.workflow_execution_id == workflow_id
 
         # Test JSON serialization path
@@ -213,7 +213,7 @@
         assert json_serialized["workflow_run_id"] == workflow_id
         assert "workflow_execution_id" not in json_serialized
 
-        json_deserialized = SystemVariable(**json_serialized)
+        json_deserialized = SystemVariable.model_validate(json_serialized)
         assert json_deserialized.workflow_execution_id == workflow_id
 
     def test_model_validator_serialization_logic(self):
@@ -222,7 +222,7 @@
 
         # Test direct instantiation with workflow_execution_id (should work)
         data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
-        system_var1 = SystemVariable(**data1)
+        system_var1 = SystemVariable.model_validate(data1)
         assert system_var1.workflow_execution_id == workflow_id
 
         # Test serialization of the above (should use alias)
@@ -236,7 +236,7 @@
             "workflow_execution_id": "should-be-removed",
             "workflow_run_id": workflow_id,
         }
-        system_var2 = SystemVariable(**data2)
+        system_var2 = SystemVariable.model_validate(data2)
         assert system_var2.workflow_execution_id == workflow_id
 
         # Verify serialization consistency
diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py
index 0ff1edc950..31fe9b2868 100644
--- a/api/tests/unit_tests/services/test_metadata_bug_complete.py
+++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py
@@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation:
 
         # But would crash when trying to create MetadataArgs
         with pytest.raises((ValueError, TypeError)):
-            MetadataArgs(**args)
+            MetadataArgs.model_validate(args)
 
     def test_7_end_to_end_validation_layers(self):
         """Test all validation layers work together correctly."""
@@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation:
 
         valid_data = {"type": "string", "name": "test_metadata"}
 
         # Should create valid Pydantic object
-        metadata_args = MetadataArgs(**valid_data)
+        metadata_args = MetadataArgs.model_validate(valid_data)
         assert metadata_args.type == "string"
         assert metadata_args.name == "test_metadata"
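
Reviewer note: the round-trip tests above reduce to one invariant: model_dump(mode="json") followed by model_validate must reproduce an equal model, and the JSON-string path must agree. A compact sketch with illustrative fields:

    from pydantic import BaseModel

    class Snapshot(BaseModel):  # illustrative stand-in
        user_id: str
        files: list[str] = []

    original = Snapshot.model_validate({"user_id": "u1", "files": ["a.txt"]})

    # dict round-trip: dump to JSON-safe primitives, validate back, compare.
    assert Snapshot.model_validate(original.model_dump(mode="json")) == original

    # The JSON-string path round-trips the same way.
    assert Snapshot.model_validate_json(original.model_dump_json()) == original
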
diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py
index d151100cf3..c8cd7025c2 100644
--- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py
+++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py
@@ -76,7 +76,7 @@ class TestMetadataNullableBug:
         # Step 2: Try to create MetadataArgs with None values
         # This should fail at Pydantic validation level
         with pytest.raises((ValueError, TypeError)):
-            metadata_args = MetadataArgs(**args)
+            metadata_args = MetadataArgs.model_validate(args)
 
         # Step 3: If we bypass Pydantic (simulating the bug scenario)
         # Move this outside the request context to avoid Flask-Login issues