use model_validate (#26182)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-10-10 17:30:13 +09:00 committed by GitHub
parent aead192743
commit ab2eacb6c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
70 changed files with 260 additions and 241 deletions

View File

@ -90,7 +90,7 @@ class ModelConfigResource(Resource):
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
try:
tool_runtime = ToolManager.get_agent_tool_runtime(
@ -124,7 +124,7 @@ class ModelConfigResource(Resource):
# encrypt agent tool parameters if it's secret-input
agent_mode = new_app_model_config.agent_mode_dict
for tool in agent_mode.get("tools") or []:
agent_tool_entity = AgentToolEntity(**tool)
agent_tool_entity = AgentToolEntity.model_validate(tool)
# get tool
key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}"

View File

@ -15,7 +15,7 @@ from core.datasource.entities.datasource_entities import DatasourceProviderType,
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
from core.indexing_runner import IndexingRunner
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
@ -257,13 +257,15 @@ class DataSourceNotionApi(Resource):
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@ -24,7 +24,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from fields.app_fields import related_app_list
@ -513,13 +513,15 @@ class DatasetIndexingEstimateApi(Resource):
for page in notion_info["pages"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": credential_id,
"notion_workspace_id": workspace_id,
"notion_obj_id": page["page_id"],
"notion_page_type": page["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)
@ -528,14 +530,16 @@ class DatasetIndexingEstimateApi(Resource):
for url in website_info_list["urls"]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": website_info_list["provider"],
"job_id": website_info_list["job_id"],
"url": url,
"tenant_id": current_user.current_tenant_id,
"mode": "crawl",
"only_main_content": website_info_list["only_main_content"],
}
),
document_model=args["doc_form"],
)
extract_settings.append(extract_setting)

View File

@ -44,7 +44,7 @@ from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from extensions.ext_database import db
from fields.document_fields import (
dataset_and_document_fields,
@ -305,7 +305,7 @@ class DatasetDocumentListApi(Resource):
"doc_language", type=str, default="English", required=False, nullable=False, location="json"
)
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.")
@ -395,7 +395,7 @@ class DatasetInitApi(Resource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
if knowledge_config.indexing_technique == "high_quality":
if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
@ -547,13 +547,15 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_user.current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
@ -562,14 +564,16 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_user.current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)

View File

@ -309,7 +309,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
segment = SegmentService.update_segment(SegmentUpdateArgs.model_validate(args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@ -564,7 +564,7 @@ class ChildChunkAddApi(Resource):
args = parser.parse_args()
try:
chunks_data = args["chunks"]
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in chunks_data]
chunks = [ChildChunkUpdateArgs.model_validate(chunk) for chunk in chunks_data]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))

View File

@ -28,7 +28,7 @@ class DatasetMetadataCreateApi(Resource):
parser.add_argument("type", type=str, required=True, nullable=False, location="json")
parser.add_argument("name", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -137,7 +137,7 @@ class DocumentMetadataEditApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("operation_data", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -88,7 +88,7 @@ class CustomizedPipelineTemplateApi(Resource):
nullable=True,
)
args = parser.parse_args()
pipeline_template_info = PipelineTemplateInfoEntity(**args)
pipeline_template_info = PipelineTemplateInfoEntity.model_validate(args)
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
return 200

View File

@ -128,7 +128,7 @@ def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseMo
raise ValueError("invalid json")
try:
payload = payload_type(**data)
payload = payload_type.model_validate(data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")

View File

@ -280,7 +280,7 @@ class DatasetListApi(DatasetApiResource):
external_knowledge_id=args["external_knowledge_id"],
embedding_model_provider=args["embedding_model_provider"],
embedding_model_name=args["embedding_model"],
retrieval_model=RetrievalModel(**args["retrieval_model"])
retrieval_model=RetrievalModel.model_validate(args["retrieval_model"])
if args["retrieval_model"] is not None
else None,
)

View File

@ -136,7 +136,7 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
}
args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
# validate args
DocumentService.document_create_args_validate(knowledge_config)
@ -221,7 +221,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:
@ -328,7 +328,7 @@ class DocumentAddByFileApi(DatasetApiResource):
}
args["data_source"] = data_source
# validate args
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None
@ -426,7 +426,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# validate args
args["original_document_id"] = str(document_id)
knowledge_config = KnowledgeConfig(**args)
knowledge_config = KnowledgeConfig.model_validate(args)
DocumentService.document_create_args_validate(knowledge_config)
try:

View File

@ -51,7 +51,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
def post(self, tenant_id, dataset_id):
"""Create metadata for a dataset."""
args = metadata_create_parser.parse_args()
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@ -200,7 +200,7 @@ class DocumentMetadataEditServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
args = document_metadata_parser.parse_args()
metadata_args = MetadataOperationData(**args)
metadata_args = MetadataOperationData.model_validate(args)
MetadataService.update_documents_metadata(dataset, metadata_args)

View File

@ -98,7 +98,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
parser.add_argument("is_published", type=bool, required=True, location="json")
args: ParseResult = parser.parse_args()
datasource_node_run_api_entity: DatasourceNodeRunApiEntity = DatasourceNodeRunApiEntity(**args)
datasource_node_run_api_entity = DatasourceNodeRunApiEntity.model_validate(args)
assert isinstance(current_user, Account)
rag_pipeline_service: RagPipelineService = RagPipelineService()
pipeline: Pipeline = rag_pipeline_service.get_pipeline(tenant_id=tenant_id, dataset_id=dataset_id)

View File

@ -252,7 +252,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = segment_update_parser.parse_args()
updated_segment = SegmentService.update_segment(
SegmentUpdateArgs(**args["segment"]), segment, document, dataset
SegmentUpdateArgs.model_validate(args["segment"]), segment, document, dataset
)
return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200

View File

@ -40,7 +40,7 @@ class AgentConfigManager:
"credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties))
if "strategy" in config["agent_mode"] and config["agent_mode"]["strategy"] not in {
"react_router",

View File

@ -116,7 +116,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
rag_pipeline_variables = []
if workflow.rag_pipeline_variables:
for v in workflow.rag_pipeline_variables:
rag_pipeline_variable = RAGPipelineVariable(**v)
rag_pipeline_variable = RAGPipelineVariable.model_validate(v)
if (
rag_pipeline_variable.belong_to_node_id
in (self.application_generate_entity.start_node_id, "shared")

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self) -> dict:
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -5,7 +5,7 @@ from collections import defaultdict
from collections.abc import Iterator, Sequence
from json import JSONDecodeError
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import func, select
from sqlalchemy.orm import Session
@ -73,9 +73,8 @@ class ProviderConfiguration(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if self.provider.provider not in original_provider_configurate_methods:
original_provider_configurate_methods[self.provider.provider] = []
for configurate_method in self.provider.configurate_methods:
@ -90,6 +89,7 @@ class ProviderConfiguration(BaseModel):
and ConfigurateMethod.PREDEFINED_MODEL not in self.provider.configurate_methods
):
self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL)
return self
def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None:
"""

View File

@ -131,7 +131,7 @@ class CodeExecutor:
if (code := response_data.get("code")) != 0:
raise CodeExecutionError(f"Got error code: {code}. Got error msg: {response_data.get('message')}")
response_code = CodeExecutionResponse(**response_data)
response_code = CodeExecutionResponse.model_validate(response_data)
if response_code.data.error:
raise CodeExecutionError(response_code.data.error)

View File

@ -26,7 +26,7 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
return [MarketplacePluginDeclaration.model_validate(plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
@ -41,7 +41,7 @@ def batch_fetch_plugin_manifests_ignore_deserialization_error(
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception:
pass

View File

@ -20,7 +20,7 @@ from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.datasource_type import DatasourceType
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@ -357,14 +357,16 @@ class IndexingRunner:
raise ValueError("no notion import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION.value,
notion_info={
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
},
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info["credential_id"],
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"document": dataset_document,
"tenant_id": dataset_document.tenant_id,
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
@ -378,14 +380,16 @@ class IndexingRunner:
raise ValueError("no website import info found")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE.value,
website_info={
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
},
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"tenant_id": dataset_document.tenant_id,
"url": data_source_info["url"],
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=dataset_document.doc_form,
)
text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

View File

@ -294,7 +294,7 @@ class ClientSession(
method="completion/complete",
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
argument=types.CompletionArgument.model_validate(argument),
),
)
),

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class I18nObject(BaseModel):
@ -9,7 +9,8 @@ class I18nObject(BaseModel):
zh_Hans: str | None = None
en_US: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.zh_Hans:
self.zh_Hans = self.en_US
return self

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from enum import Enum, StrEnum, auto
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
@ -46,10 +46,11 @@ class FormOption(BaseModel):
value: str
show_on: list[FormShowOnObject] = []
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
if not self.label:
self.label = I18nObject(en_US=self.value)
return self
class CredentialFormSchema(BaseModel):

View File

@ -269,17 +269,17 @@ class ModelProviderFactory:
}
if model_type == ModelType.LLM:
return LargeLanguageModel(**init_params) # type: ignore
return LargeLanguageModel.model_validate(init_params)
elif model_type == ModelType.TEXT_EMBEDDING:
return TextEmbeddingModel(**init_params) # type: ignore
return TextEmbeddingModel.model_validate(init_params)
elif model_type == ModelType.RERANK:
return RerankModel(**init_params) # type: ignore
return RerankModel.model_validate(init_params)
elif model_type == ModelType.SPEECH2TEXT:
return Speech2TextModel(**init_params) # type: ignore
return Speech2TextModel.model_validate(init_params)
elif model_type == ModelType.MODERATION:
return ModerationModel(**init_params) # type: ignore
return ModerationModel.model_validate(init_params)
elif model_type == ModelType.TTS:
return TTSModel(**init_params) # type: ignore
return TTSModel.model_validate(init_params)
def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]:
"""

View File

@ -51,7 +51,7 @@ class ApiModeration(Moderation):
params = ModerationInputParams(app_id=self.app_id, inputs=inputs, query=query)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.model_dump())
return ModerationInputsResult(**result)
return ModerationInputsResult.model_validate(result)
return ModerationInputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response
@ -67,7 +67,7 @@ class ApiModeration(Moderation):
params = ModerationOutputParams(app_id=self.app_id, text=text)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.model_dump())
return ModerationOutputsResult(**result)
return ModerationOutputsResult.model_validate(result)
return ModerationOutputsResult(
flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response

View File

@ -84,15 +84,15 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
for i in range(len(v)):
if v[i]["role"] == PromptMessageRole.USER.value:
v[i] = UserPromptMessage(**v[i])
v[i] = UserPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.ASSISTANT.value:
v[i] = AssistantPromptMessage(**v[i])
v[i] = AssistantPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.SYSTEM.value:
v[i] = SystemPromptMessage(**v[i])
v[i] = SystemPromptMessage.model_validate(v[i])
elif v[i]["role"] == PromptMessageRole.TOOL.value:
v[i] = ToolPromptMessage(**v[i])
v[i] = ToolPromptMessage.model_validate(v[i])
else:
v[i] = PromptMessage(**v[i])
v[i] = PromptMessage.model_validate(v[i])
return v

View File

@ -94,7 +94,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -104,13 +104,13 @@ class BasePluginClient:
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) # type: ignore
yield type_(**json.loads(line)) # type: ignore
def _request_with_model(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | None = None,
params: dict | None = None,
@ -120,13 +120,13 @@ class BasePluginClient:
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json()) # type: ignore
return type_(**response.json()) # type: ignore
def _request_with_plugin_daemon_response(
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -140,22 +140,22 @@ class BasePluginClient:
response = self._request(method, path, headers, data, params, files)
response.raise_for_status()
except HTTPError as e:
msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path)
raise e
except Exception as e:
msg = f"Failed to request plugin daemon, url: {path}"
logger.exception(msg)
logger.exception("Failed to request plugin daemon, url: %s", path)
raise ValueError(msg) from e
try:
json_response = response.json()
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
# https://stackoverflow.com/questions/59634937/variable-foo-class-is-not-valid-as-type-but-why
rep = PluginDaemonBasicResponse[type_].model_validate(json_response) # type: ignore
except Exception:
msg = (
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}],"
f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type_.__name__)}],"
f" url: {path}"
)
logger.exception(msg)
@ -163,7 +163,7 @@ class BasePluginClient:
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise ValueError(f"{rep.message}, code: {rep.code}")
@ -178,7 +178,7 @@ class BasePluginClient:
self,
method: str,
path: str,
type: type[T],
type_: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
@ -189,7 +189,7 @@ class BasePluginClient:
"""
for line in self._stream_request(method, path, params, headers, data, files):
try:
rep = PluginDaemonBasicResponse[type].model_validate_json(line) # type: ignore
rep = PluginDaemonBasicResponse[type_].model_validate_json(line) # type: ignore
except (ValueError, TypeError):
# TODO modify this when line_data has code and message
try:
@ -204,7 +204,7 @@ class BasePluginClient:
if rep.code != 0:
if rep.code == -500:
try:
error = PluginDaemonError(**json.loads(rep.message))
error = PluginDaemonError.model_validate(json.loads(rep.message))
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)

View File

@ -46,7 +46,9 @@ class PluginDatasourceManager(BasePluginClient):
params={"page": 1, "page_size": 256},
transformer=transformer,
)
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
local_file_datasource_provider = PluginDatasourceProviderEntity.model_validate(
self._get_local_file_datasource_provider()
)
for provider in response:
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
@ -104,7 +106,7 @@ class PluginDatasourceManager(BasePluginClient):
Fetch datasource provider for the given tenant and plugin.
"""
if provider_id == "langgenius/file/file":
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
return PluginDatasourceProviderEntity.model_validate(self._get_local_file_datasource_provider())
tool_provider_id = DatasourceProviderID(provider_id)

View File

@ -162,7 +162,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
type=LLMResultChunk,
type_=LLMResultChunk,
data=jsonable_encoder(
{
"user_id": user_id,
@ -208,7 +208,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
type=PluginLLMNumTokensResponse,
type_=PluginLLMNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -250,7 +250,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
type=TextEmbeddingResult,
type_=TextEmbeddingResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -291,7 +291,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
type=PluginTextEmbeddingNumTokensResponse,
type_=PluginTextEmbeddingNumTokensResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -334,7 +334,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
type=RerankResult,
type_=RerankResult,
data=jsonable_encoder(
{
"user_id": user_id,
@ -378,7 +378,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -422,7 +422,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
type=PluginVoicesResponse,
type_=PluginVoicesResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -466,7 +466,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
type=PluginStringResultResponse,
type_=PluginStringResultResponse,
data=jsonable_encoder(
{
"user_id": user_id,
@ -506,7 +506,7 @@ class PluginModelClient(BasePluginClient):
response = self._request_with_plugin_daemon_response_stream(
method="POST",
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
type=PluginBasicBooleanResponse,
type_=PluginBasicBooleanResponse,
data=jsonable_encoder(
{
"user_id": user_id,

View File

@ -134,7 +134,7 @@ class RetrievalService:
if not dataset:
return []
metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
MetadataCondition.model_validate(metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id,

View File

@ -17,9 +17,6 @@ class NotionInfo(BaseModel):
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)
class WebsiteInfo(BaseModel):
"""
@ -47,6 +44,3 @@ class ExtractSetting(BaseModel):
website_info: WebsiteInfo | None = None
document_model: str | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
super().__init__(**data)

View File

@ -38,11 +38,11 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
rules = Rule.model_validate(automatic_rule)
else:
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
# Split the text documents into nodes.
if not rules.segmentation:
raise ValueError("No segmentation found in rules.")

View File

@ -40,7 +40,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
all_documents: list[Document] = []
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
@ -110,7 +110,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
Document.model_validate(child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
@ -224,7 +224,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
return child_nodes
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
parent_childs = ParentChildStructureChunk(**chunks)
parent_childs = ParentChildStructureChunk.model_validate(chunks)
documents = []
for parent_child in parent_childs.parent_child_chunks:
metadata = {
@ -274,7 +274,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.create(all_child_documents)
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
parent_childs = ParentChildStructureChunk(**chunks)
parent_childs = ParentChildStructureChunk.model_validate(chunks)
preview = []
for parent_child in parent_childs.parent_child_chunks:
preview.append({"content": parent_child.parent_content, "child_chunks": parent_child.child_contents})

View File

@ -47,7 +47,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
rules = Rule.model_validate(process_rule.get("rules"))
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
@ -168,7 +168,7 @@ class QAIndexProcessor(BaseIndexProcessor):
return docs
def index(self, dataset: Dataset, document: DatasetDocument, chunks: Any):
qa_chunks = QAStructureChunk(**chunks)
qa_chunks = QAStructureChunk.model_validate(chunks)
documents = []
for qa_chunk in qa_chunks.qa_chunks:
metadata = {
@ -191,7 +191,7 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError("Indexing technique must be high quality.")
def format_preview(self, chunks: Any) -> Mapping[str, Any]:
qa_chunks = QAStructureChunk(**chunks)
qa_chunks = QAStructureChunk.model_validate(chunks)
preview = []
for qa_chunk in qa_chunks.qa_chunks:
preview.append({"question": qa_chunk.question, "answer": qa_chunk.answer})

View File

@ -90,7 +90,7 @@ class BuiltinToolProviderController(ToolProviderController):
tools.append(
assistant_tool_class(
provider=provider,
entity=ToolEntity(**tool),
entity=ToolEntity.model_validate(tool),
runtime=ToolRuntime(tenant_id=""),
)
)

View File

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
class I18nObject(BaseModel):
@ -11,11 +11,12 @@ class I18nObject(BaseModel):
pt_BR: str | None = Field(default=None)
ja_JP: str | None = Field(default=None)
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _populate_missing_locales(self):
self.zh_Hans = self.zh_Hans or self.en_US
self.pt_BR = self.pt_BR or self.en_US
self.ja_JP = self.ja_JP or self.en_US
return self
def to_dict(self):
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}

View File

@ -54,7 +54,7 @@ class MCPToolProviderController(ToolProviderController):
"""
tools = []
tools_data = json.loads(db_provider.tools)
remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
remote_mcp_tools = [RemoteMCPTool.model_validate(tool) for tool in tools_data]
user = db_provider.load_user()
tools = [
ToolEntity(

View File

@ -1008,7 +1008,7 @@ class ToolManager:
config = tool_configurations.get(parameter.name, {})
if not (config and isinstance(config, dict) and config.get("value") is not None):
continue
tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:

View File

@ -105,10 +105,10 @@ class RedisChannel:
command_type = CommandType(command_type_value)
if command_type == CommandType.ABORT:
return AbortCommand(**data)
return AbortCommand.model_validate(data)
else:
# For other command types, use base class
return GraphEngineCommand(**data)
return GraphEngineCommand.model_validate(data)
except (ValueError, TypeError):
return None

View File

@ -16,7 +16,7 @@ class EndNode(Node):
_node_data: EndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = EndNodeData(**data)
self._node_data = EndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -18,7 +18,7 @@ class IterationStartNode(Node):
_node_data: IterationStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = IterationStartNodeData(**data)
self._node_data = IterationStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -41,7 +41,7 @@ class ListOperatorNode(Node):
_node_data: ListOperatorNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = ListOperatorNodeData(**data)
self._node_data = ListOperatorNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -18,7 +18,7 @@ class LoopEndNode(Node):
_node_data: LoopEndNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopEndNodeData(**data)
self._node_data = LoopEndNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -18,7 +18,7 @@ class LoopStartNode(Node):
_node_data: LoopStartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = LoopStartNodeData(**data)
self._node_data = LoopStartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -16,7 +16,7 @@ class StartNode(Node):
_node_data: StartNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = StartNodeData(**data)
self._node_data = StartNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -15,7 +15,7 @@ class VariableAggregatorNode(Node):
_node_data: VariableAssignerNodeData
def init_node_data(self, data: Mapping[str, Any]):
self._node_data = VariableAssignerNodeData(**data)
self._node_data = VariableAssignerNodeData.model_validate(data)
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._node_data.error_strategy

View File

@ -14,7 +14,7 @@ def handle(sender, **kwargs):
for node_data in synced_draft_workflow.graph_dict.get("nodes", []):
if node_data.get("data", {}).get("type") == NodeType.TOOL.value:
try:
tool_entity = ToolEntity(**node_data["data"])
tool_entity = ToolEntity.model_validate(node_data["data"])
tool_runtime = ToolManager.get_tool_runtime(
provider_type=tool_entity.provider_type,
provider_id=tool_entity.provider_id,

View File

@ -61,7 +61,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]:
for node in knowledge_retrieval_nodes:
try:
node_data = KnowledgeRetrievalNodeData(**node.get("data", {}))
node_data = KnowledgeRetrievalNodeData.model_validate(node.get("data", {}))
dataset_ids.update(dataset_id for dataset_id in node_data.dataset_ids)
except Exception:
continue

View File

@ -754,7 +754,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
@ -772,7 +772,7 @@ class DocumentSegment(Base):
if process_rule and process_rule.mode == "hierarchical":
rules_dict = process_rule.rules_dict
if rules_dict:
rules = Rule(**rules_dict)
rules = Rule.model_validate(rules_dict)
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)

View File

@ -152,7 +152,7 @@ class ApiToolProvider(Base):
def tools(self) -> list["ApiToolBundle"]:
from core.tools.entities.tool_bundle import ApiToolBundle
return [ApiToolBundle(**tool) for tool in json.loads(self.tools_str)]
return [ApiToolBundle.model_validate(tool) for tool in json.loads(self.tools_str)]
@property
def credentials(self) -> dict[str, Any]:
@ -242,7 +242,10 @@ class WorkflowToolProvider(Base):
def parameter_configurations(self) -> list["WorkflowToolParameterConfiguration"]:
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
return [WorkflowToolParameterConfiguration(**config) for config in json.loads(self.parameter_configuration)]
return [
WorkflowToolParameterConfiguration.model_validate(config)
for config in json.loads(self.parameter_configuration)
]
@property
def app(self) -> App | None:
@ -312,7 +315,7 @@ class MCPToolProvider(Base):
def mcp_tools(self) -> list["MCPTool"]:
from core.mcp.types import Tool as MCPTool
return [MCPTool(**tool) for tool in json.loads(self.tools)]
return [MCPTool.model_validate(tool) for tool in json.loads(self.tools)]
@property
def provider_icon(self) -> Mapping[str, str] | str:
@ -552,4 +555,4 @@ class DeprecatedPublishedAppTool(Base):
def description_i18n(self) -> "I18nObject":
from core.tools.entities.common_entities import I18nObject
return I18nObject(**json.loads(self.description))
return I18nObject.model_validate(json.loads(self.description))

View File

@ -659,31 +659,31 @@ class AppDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@ -773,7 +773,7 @@ class AppDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

View File

@ -70,7 +70,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)
@classmethod
def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]:
@ -100,7 +100,7 @@ class EnterpriseService:
data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params)
if not data:
raise ValueError("No data found.")
return WebAppSettings(**data)
return WebAppSettings.model_validate(data)
@classmethod
def update_app_access_mode(cls, app_id: str, access_mode: str):

View File

@ -1,6 +1,7 @@
from collections.abc import Sequence
from enum import Enum
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from configs import dify_config
from core.entities.model_entities import (
@ -71,7 +72,7 @@ class ProviderResponse(BaseModel):
icon_large: I18nObject | None = None
background: str | None = None
help: ProviderHelpEntity | None = None
supported_model_types: list[ModelType]
supported_model_types: Sequence[ModelType]
configurate_methods: list[ConfigurateMethod]
provider_credential_schema: ProviderCredentialSchema | None = None
model_credential_schema: ModelCredentialSchema | None = None
@ -82,9 +83,8 @@ class ProviderResponse(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@ -97,6 +97,7 @@ class ProviderResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class ProviderWithModelsResponse(BaseModel):
@ -112,9 +113,8 @@ class ProviderWithModelsResponse(BaseModel):
status: CustomConfigurationStatus
models: list[ProviderModelWithStatusEntity]
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@ -127,6 +127,7 @@ class ProviderWithModelsResponse(BaseModel):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class SimpleProviderEntityResponse(SimpleProviderEntity):
@ -136,9 +137,8 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
tenant_id: str
def __init__(self, **data):
super().__init__(**data)
@model_validator(mode="after")
def _(self):
url_prefix = (
dify_config.CONSOLE_API_URL + f"/console/api/workspaces/{self.tenant_id}/model-providers/{self.provider}"
)
@ -151,6 +151,7 @@ class SimpleProviderEntityResponse(SimpleProviderEntity):
self.icon_large = I18nObject(
en_US=f"{url_prefix}/icon_large/en_US", zh_Hans=f"{url_prefix}/icon_large/zh_Hans"
)
return self
class DefaultModelResponse(BaseModel):

View File

@ -46,7 +46,7 @@ class HitTestingService:
from core.app.app_config.entities import MetadataFilteringCondition
metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions)
metadata_filtering_conditions = MetadataFilteringCondition.model_validate(metadata_filtering_conditions)
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
dataset_ids=[dataset.id],

View File

@ -123,7 +123,7 @@ class OpsService:
config_class: type[BaseTracingConfig] = provider_config["config_class"]
other_keys: list[str] = provider_config["other_keys"]
default_config_instance: BaseTracingConfig = config_class(**tracing_config)
default_config_instance = config_class.model_validate(tracing_config)
for key in other_keys:
if key in tracing_config and tracing_config[key] == "":
tracing_config[key] = getattr(default_config_instance, key, None)

View File

@ -269,7 +269,7 @@ class PluginMigration:
for tool in agent_config["tools"]:
if isinstance(tool, dict):
try:
tool_entity = AgentToolEntity(**tool)
tool_entity = AgentToolEntity.model_validate(tool)
if (
tool_entity.provider_type == ToolProviderType.BUILT_IN.value
and tool_entity.provider_id not in excluded_providers

View File

@ -358,7 +358,7 @@ class RagPipelineService:
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration)
# update dataset
dataset = pipeline.retrieve_dataset(session=session)

View File

@ -288,7 +288,7 @@ class RagPipelineDslService:
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if (
dataset
and pipeline.is_published
@ -426,7 +426,7 @@ class RagPipelineDslService:
dataset_id = None
for node in nodes:
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration(**node.get("data", {}))
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if not dataset:
dataset = Dataset(
tenant_id=account.current_tenant_id,
@ -734,35 +734,35 @@ class RagPipelineDslService:
typ = node.get("data", {}).get("type")
match typ:
case NodeType.TOOL.value:
tool_entity = ToolNodeData(**node["data"])
tool_entity = ToolNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
)
case NodeType.DATASOURCE.value:
datasource_entity = DatasourceNodeData(**node["data"])
datasource_entity = DatasourceNodeData.model_validate(node["data"])
if datasource_entity.provider_type != "local_file":
dependencies.append(datasource_entity.plugin_id)
case NodeType.LLM.value:
llm_entity = LLMNodeData(**node["data"])
llm_entity = LLMNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
)
case NodeType.QUESTION_CLASSIFIER.value:
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
question_classifier_entity.model.provider
),
)
case NodeType.PARAMETER_EXTRACTOR.value:
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"])
dependencies.append(
DependenciesAnalysisService.analyze_model_provider_dependency(
parameter_extractor_entity.model.provider
),
)
case NodeType.KNOWLEDGE_INDEX.value:
knowledge_index_entity = KnowledgeConfiguration(**node["data"])
knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"])
if knowledge_index_entity.indexing_technique == "high_quality":
if knowledge_index_entity.embedding_model_provider:
dependencies.append(
@ -783,7 +783,7 @@ class RagPipelineDslService:
),
)
case NodeType.KNOWLEDGE_RETRIEVAL.value:
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"])
if knowledge_retrieval_entity.retrieval_mode == "multiple":
if knowledge_retrieval_entity.multiple_retrieval_config:
if (
@ -873,7 +873,7 @@ class RagPipelineDslService:
"""
Returns the leaked dependencies in current workspace
"""
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
dependencies = [PluginDependency.model_validate(dep) for dep in dsl_dependencies]
if not dependencies:
return []

View File

@ -156,13 +156,13 @@ class RagPipelineTransformService:
self, dataset: Dataset, doc_form: str, indexing_technique: str | None, retrieval_model: dict, node: dict
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration_dict)
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model
knowledge_configuration.embedding_model_provider = dataset.embedding_model_provider
if retrieval_model:
retrieval_setting = RetrievalSetting(**retrieval_model)
retrieval_setting = RetrievalSetting.model_validate(retrieval_model)
if indexing_technique == "economy":
retrieval_setting.search_method = "keyword_search"
knowledge_configuration.retrieval_model = retrieval_setting

View File

@ -242,7 +242,7 @@ class ToolTransformService:
is_team_authorization=db_provider.authed,
server_url=db_provider.masked_server_url,
tools=ToolTransformService.mcp_tool_to_user_tool(
db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
db_provider, [MCPTool.model_validate(tool) for tool in json.loads(db_provider.tools)]
),
updated_at=int(db_provider.updated_at.timestamp()),
label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
@ -387,6 +387,7 @@ class ToolTransformService:
labels=labels or [],
)
else:
assert tool.operation_id
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",

View File

@ -36,7 +36,7 @@ def process_trace_tasks(file_info):
if trace_info.get("workflow_data"):
trace_info["workflow_data"] = WorkflowRun.from_dict(data=trace_info["workflow_data"])
if trace_info.get("documents"):
trace_info["documents"] = [Document(**doc) for doc in trace_info["documents"]]
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
if trace_instance:

View File

@ -79,7 +79,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@ -112,7 +112,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

View File

@ -100,7 +100,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity.model_validate(rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
@ -133,7 +133,7 @@ def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any],
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
entity = RagPipelineGenerateEntity.model_validate(application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)

View File

@ -36,7 +36,7 @@ def test_api_tool(setup_http_mock):
entity=ToolEntity(
identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
),
api_bundle=ApiToolBundle(**tool_bundle),
api_bundle=ApiToolBundle.model_validate(tool_bundle),
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
provider_id="test_tool",
)

View File

@ -11,8 +11,8 @@ def test_default_value():
config = valid_config.copy()
del config[key]
with pytest.raises(ValidationError) as e:
MilvusConfig(**config)
MilvusConfig.model_validate(config)
assert e.value.errors()[0]["msg"] == f"Value error, config MILVUS_{key.upper()} is required"
config = MilvusConfig(**valid_config)
config = MilvusConfig.model_validate(valid_config)
assert config.database == "default"

View File

@ -35,7 +35,7 @@ def list_operator_node():
"extract_by": ExtractConfig(enabled=False, serial="1"),
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_data = ListOperatorNodeData.model_validate(config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),

View File

@ -17,7 +17,7 @@ def test_init_question_classifier_node_data():
"vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}},
}
node_data = QuestionClassifierNodeData(**data)
node_data = QuestionClassifierNodeData.model_validate(data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"
@ -49,7 +49,7 @@ def test_init_question_classifier_node_data_without_vision_config():
},
}
node_data = QuestionClassifierNodeData(**data)
node_data = QuestionClassifierNodeData.model_validate(data)
assert node_data.query_variable_selector == ["id", "name"]
assert node_data.model.provider == "openai"

View File

@ -46,7 +46,7 @@ class TestSystemVariableSerialization:
def test_basic_deserialization(self):
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
# Test with complete data
system_var = SystemVariable(**COMPLETE_VALID_DATA)
system_var = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Verify all fields are correctly mapped
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
@ -59,7 +59,7 @@ class TestSystemVariableSerialization:
assert system_var.files == []
# Test with minimal data (only required fields)
minimal_var = SystemVariable(**VALID_BASE_DATA)
minimal_var = SystemVariable.model_validate(VALID_BASE_DATA)
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
@ -75,12 +75,12 @@ class TestSystemVariableSerialization:
# Test workflow_run_id only (preferred alias)
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var1 = SystemVariable(**data_run_id)
system_var1 = SystemVariable.model_validate(data_run_id)
assert system_var1.workflow_execution_id == workflow_id
# Test workflow_execution_id only (direct field name)
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var2 = SystemVariable(**data_execution_id)
system_var2 = SystemVariable.model_validate(data_execution_id)
assert system_var2.workflow_execution_id == workflow_id
# Test both present - workflow_run_id should take precedence
@ -89,17 +89,17 @@ class TestSystemVariableSerialization:
"workflow_execution_id": "should-be-ignored",
"workflow_run_id": workflow_id,
}
system_var3 = SystemVariable(**data_both)
system_var3 = SystemVariable.model_validate(data_both)
assert system_var3.workflow_execution_id == workflow_id
# Test neither present - should be None
system_var4 = SystemVariable(**VALID_BASE_DATA)
system_var4 = SystemVariable.model_validate(VALID_BASE_DATA)
assert system_var4.workflow_execution_id is None
def test_serialization_round_trip(self):
"""Test that serialize → deserialize produces the same result with alias handling."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Serialize to dict
serialized = original.model_dump(mode="json")
@ -110,7 +110,7 @@ class TestSystemVariableSerialization:
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize back
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
# Verify all fields match after round-trip
assert deserialized.user_id == original.user_id
@ -125,7 +125,7 @@ class TestSystemVariableSerialization:
def test_json_round_trip(self):
"""Test JSON serialization/deserialization consistency with proper structure."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
original = SystemVariable.model_validate(COMPLETE_VALID_DATA)
# Serialize to JSON string
json_str = original.model_dump_json()
@ -137,7 +137,7 @@ class TestSystemVariableSerialization:
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize from JSON data
deserialized = SystemVariable(**json_data)
deserialized = SystemVariable.model_validate(json_data)
# Verify key fields match after JSON round-trip
assert deserialized.workflow_execution_id == original.workflow_execution_id
@ -149,13 +149,13 @@ class TestSystemVariableSerialization:
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
# Test with empty files list
data_empty = {**VALID_BASE_DATA, "files": []}
system_var_empty = SystemVariable(**data_empty)
system_var_empty = SystemVariable.model_validate(data_empty)
assert system_var_empty.files == []
# Test with single File object
test_file = create_test_file()
data_single = {**VALID_BASE_DATA, "files": [test_file]}
system_var_single = SystemVariable(**data_single)
system_var_single = SystemVariable.model_validate(data_single)
assert len(system_var_single.files) == 1
assert system_var_single.files[0].filename == "test.txt"
assert system_var_single.files[0].tenant_id == "test-tenant-id"
@ -179,14 +179,14 @@ class TestSystemVariableSerialization:
)
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
system_var_multiple = SystemVariable(**data_multiple)
system_var_multiple = SystemVariable.model_validate(data_multiple)
assert len(system_var_multiple.files) == 2
assert system_var_multiple.files[0].filename == "doc1.txt"
assert system_var_multiple.files[1].filename == "image.jpg"
# Verify files field serialization/deserialization
serialized = system_var_multiple.model_dump(mode="json")
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
assert len(deserialized.files) == 2
assert deserialized.files[0].filename == "doc1.txt"
assert deserialized.files[1].filename == "image.jpg"
@ -197,7 +197,7 @@ class TestSystemVariableSerialization:
# Create with workflow_run_id (alias)
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var = SystemVariable(**data_with_alias)
system_var = SystemVariable.model_validate(data_with_alias)
# Serialize and verify alias is used
serialized = system_var.model_dump()
@ -205,7 +205,7 @@ class TestSystemVariableSerialization:
assert "workflow_execution_id" not in serialized
# Deserialize and verify field mapping
deserialized = SystemVariable(**serialized)
deserialized = SystemVariable.model_validate(serialized)
assert deserialized.workflow_execution_id == workflow_id
# Test JSON serialization path
@ -213,7 +213,7 @@ class TestSystemVariableSerialization:
assert json_serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in json_serialized
json_deserialized = SystemVariable(**json_serialized)
json_deserialized = SystemVariable.model_validate(json_serialized)
assert json_deserialized.workflow_execution_id == workflow_id
def test_model_validator_serialization_logic(self):
@ -222,7 +222,7 @@ class TestSystemVariableSerialization:
# Test direct instantiation with workflow_execution_id (should work)
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var1 = SystemVariable(**data1)
system_var1 = SystemVariable.model_validate(data1)
assert system_var1.workflow_execution_id == workflow_id
# Test serialization of the above (should use alias)
@ -236,7 +236,7 @@ class TestSystemVariableSerialization:
"workflow_execution_id": "should-be-removed",
"workflow_run_id": workflow_id,
}
system_var2 = SystemVariable(**data2)
system_var2 = SystemVariable.model_validate(data2)
assert system_var2.workflow_execution_id == workflow_id
# Verify serialization consistency

View File

@ -118,7 +118,7 @@ class TestMetadataBugCompleteValidation:
# But would crash when trying to create MetadataArgs
with pytest.raises((ValueError, TypeError)):
MetadataArgs(**args)
MetadataArgs.model_validate(args)
def test_7_end_to_end_validation_layers(self):
"""Test all validation layers work together correctly."""
@ -131,7 +131,7 @@ class TestMetadataBugCompleteValidation:
valid_data = {"type": "string", "name": "test_metadata"}
# Should create valid Pydantic object
metadata_args = MetadataArgs(**valid_data)
metadata_args = MetadataArgs.model_validate(valid_data)
assert metadata_args.type == "string"
assert metadata_args.name == "test_metadata"

View File

@ -76,7 +76,7 @@ class TestMetadataNullableBug:
# Step 2: Try to create MetadataArgs with None values
# This should fail at Pydantic validation level
with pytest.raises((ValueError, TypeError)):
metadata_args = MetadataArgs(**args)
metadata_args = MetadataArgs.model_validate(args)
# Step 3: If we bypass Pydantic (simulating the bug scenario)
# Move this outside the request context to avoid Flask-Login issues