[CHORE]: remove redundant-cast (#24807)

This commit is contained in:
willzhao 2025-09-01 14:05:32 +08:00 committed by GitHub
parent f11131f8b5
commit ffba341258
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 54 additions and 90 deletions

View File

@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
conversation_variables=conversation_variables,
)
# init graph

View File

@ -3,7 +3,7 @@ import base64
from libs import rsa
def obfuscated_token(token: str):
def obfuscated_token(token: str) -> str:
if not token:
return token
if len(token) <= 8:

View File

@ -158,8 +158,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
Union[LLMResult, Generator],
self._round_robin_invoke(
@ -188,8 +186,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, LargeLanguageModel):
raise Exception("Model type instance is not LargeLanguageModel")
self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance)
return cast(
int,
self._round_robin_invoke(
@ -214,8 +210,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
TextEmbeddingResult,
self._round_robin_invoke(
@ -237,8 +231,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TextEmbeddingModel):
raise Exception("Model type instance is not TextEmbeddingModel")
self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance)
return cast(
list[int],
self._round_robin_invoke(
@ -269,8 +261,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, RerankModel):
raise Exception("Model type instance is not RerankModel")
self.model_type_instance = cast(RerankModel, self.model_type_instance)
return cast(
RerankResult,
self._round_robin_invoke(
@ -295,8 +285,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, ModerationModel):
raise Exception("Model type instance is not ModerationModel")
self.model_type_instance = cast(ModerationModel, self.model_type_instance)
return cast(
bool,
self._round_robin_invoke(
@ -318,8 +306,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, Speech2TextModel):
raise Exception("Model type instance is not Speech2TextModel")
self.model_type_instance = cast(Speech2TextModel, self.model_type_instance)
return cast(
str,
self._round_robin_invoke(
@ -343,8 +329,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return cast(
Iterable[bytes],
self._round_robin_invoke(
@ -404,8 +388,6 @@ class ModelInstance:
"""
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
self.model_type_instance = cast(TTSModel, self.model_type_instance)
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
)

View File

@ -87,7 +87,6 @@ class PromptMessageUtil:
if isinstance(prompt_message.content, list):
for content in prompt_message.content:
if content.type == PromptMessageContentType.TEXT:
content = cast(TextPromptMessageContent, content)
text += content.data
else:
content = cast(ImagePromptMessageContent, content)

View File

@ -2,7 +2,7 @@ import contextlib
import json
from collections import defaultdict
from json import JSONDecodeError
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
@ -154,8 +154,8 @@ class ProviderManager:
for provider_entity in provider_entities:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_PROVIDER_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_PROVIDER_EXCLUDES_SET),
include_set=dify_config.POSITION_PROVIDER_INCLUDES_SET,
exclude_set=dify_config.POSITION_PROVIDER_EXCLUDES_SET,
data=provider_entity,
name_func=lambda x: x.provider,
):

View File

@ -3,7 +3,7 @@ import os
import uuid
from collections.abc import Generator, Iterable, Sequence
from itertools import islice
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union
import qdrant_client
from flask import current_app
@ -426,7 +426,6 @@ class QdrantVector(BaseVector):
def _reload_if_needed(self):
if isinstance(self._client, QdrantLocal):
self._client = cast(QdrantLocal, self._client)
self._client._load()
@classmethod

View File

@ -2,7 +2,7 @@
import re
from pathlib import Path
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
@ -76,7 +76,7 @@ class MarkdownExtractor(BaseExtractor):
markdown_tups.append((current_header, current_text))
markdown_tups = [
(re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
(re.sub(r"#", "", key).strip() if key else None, re.sub(r"<.*?>", "", value))
for key, value in markdown_tups
]

View File

@ -385,4 +385,4 @@ class NotionExtractor(BaseExtractor):
f"No notion data source binding found for tenant {tenant_id} and notion workspace {notion_workspace_id}"
)
return cast(str, data_source_binding.access_token)
return data_source_binding.access_token

View File

@ -2,7 +2,7 @@
import contextlib
from collections.abc import Iterator
from typing import Optional, cast
from typing import Optional
from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor
@ -27,7 +27,7 @@ class PdfExtractor(BaseExtractor):
plaintext_file_exists = False
if self._file_cache_key:
with contextlib.suppress(FileNotFoundError):
text = cast(bytes, storage.load(self._file_cache_key)).decode("utf-8")
text = storage.load(self._file_cache_key).decode("utf-8")
plaintext_file_exists = True
return [Document(page_content=text)]
documents = list(self.load())

View File

@ -331,16 +331,13 @@ class ToolManager:
if controller_tools is None or len(controller_tools) == 0:
raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
return cast(
WorkflowTool,
controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
),
return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
)
elif provider_type == ToolProviderType.APP:
raise NotImplementedError("app provider not implemented")
@ -648,8 +645,8 @@ class ToolManager:
for provider in builtin_providers:
# handle include, exclude
if is_filtered(
include_set=cast(set[str], dify_config.POSITION_TOOL_INCLUDES_SET),
exclude_set=cast(set[str], dify_config.POSITION_TOOL_EXCLUDES_SET),
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider,
name_func=lambda x: x.identity.name,
):

View File

@ -3,7 +3,7 @@ from collections.abc import Generator
from datetime import date, datetime
from decimal import Decimal
from mimetypes import guess_extension
from typing import Optional, cast
from typing import Optional
from uuid import UUID
import numpy as np
@ -159,8 +159,7 @@ class ToolFileMessageTransformer:
elif message.type == ToolInvokeMessage.MessageType.JSON:
if isinstance(message.message, ToolInvokeMessage.JsonMessage):
json_msg = cast(ToolInvokeMessage.JsonMessage, message.message)
json_msg.json_object = safe_json_value(json_msg.json_object)
message.message.json_object = safe_json_value(message.message.json_object)
yield message
else:
yield message

View File

@ -129,17 +129,14 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

View File

@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Generator
from typing import Any, Optional, cast
from typing import Any, Optional
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
@ -204,14 +204,14 @@ class WorkflowTool(Tool):
item = self._update_file_mapping(item)
file = build_from_mapping(
mapping=item,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)
elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
value = self._update_file_mapping(value)
file = build_from_mapping(
mapping=value,
tenant_id=str(cast(ToolRuntime, self.runtime).tenant_id),
tenant_id=str(self.runtime.tenant_id),
)
files.append(file)

View File

@ -1,5 +1,5 @@
from collections.abc import Sequence
from typing import Annotated, TypeAlias, cast
from typing import Annotated, TypeAlias
from uuid import uuid4
from pydantic import Discriminator, Field, Tag
@ -86,7 +86,7 @@ class SecretVariable(StringVariable):
@property
def log(self) -> str:
return cast(str, encrypter.obfuscated_token(self.value))
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):

View File

@ -374,7 +374,7 @@ class GraphEngine:
if len(sub_edge_mappings) == 0:
continue
edge = cast(GraphEdge, sub_edge_mappings[0])
edge = sub_edge_mappings[0]
if edge.run_condition is None:
logger.warning("Edge %s run condition is None", edge.target_node_id)
continue

View File

@ -153,7 +153,7 @@ class AgentNode(BaseNode):
messages=message_stream,
tool_info={
"icon": self.agent_strategy_icon,
"agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
"agent_strategy": self._node_data.agent_strategy_name,
},
parameters_for_log=parameters_for_log,
user_id=self.user_id,
@ -394,8 +394,7 @@ class AgentNode(BaseNode):
current_plugin = next(
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
== cast(AgentNodeData, self._node_data).agent_strategy_provider_name
if f"{plugin.plugin_id}/{plugin.name}" == self._node_data.agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:

View File

@ -302,12 +302,12 @@ def _extract_text_from_yaml(file_content: bytes) -> str:
encoding = "utf-8"
yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e:
# If decoding fails, try with utf-8 as last resort
try:
yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore"))
return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False))
return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)
except (UnicodeDecodeError, yaml.YAMLError):
raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e

View File

@ -139,7 +139,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
node_data = cast(ParameterExtractorNodeData, self._node_data)
node_data = self._node_data
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""

View File

@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
from typing import TYPE_CHECKING, Any, Optional
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@ -109,7 +109,7 @@ class QuestionClassifierNode(BaseNode):
return "1"
def _run(self):
node_data = cast(QuestionClassifierNodeData, self._node_data)
node_data = self._node_data
variable_pool = self.graph_runtime_state.variable_pool
# extract variables

View File

@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -57,7 +57,7 @@ class ToolNode(BaseNode):
Run the tool node
"""
node_data = cast(ToolNodeData, self._node_data)
node_data = self._node_data
# fetch tool icon
tool_info = {

View File

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from typing import Any, Optional
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@ -261,7 +261,6 @@ class WorkflowEntry:
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node: BaseNode = node_cls(
id=str(uuid.uuid4()),

View File

@ -3,7 +3,7 @@ import os
import urllib.parse
import uuid
from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast
from typing import Any
import httpx
from sqlalchemy import select
@ -258,7 +258,6 @@ def _get_remote_file_info(url: str):
mime_type = ""
resp = ssrf_proxy.head(url, follow_redirects=True)
resp = cast(httpx.Response, resp)
if resp.status_code == httpx.codes.OK:
if content_disposition := resp.headers.get("Content-Disposition"):
filename = str(content_disposition.split("filename=")[-1].strip('"'))

View File

@ -308,7 +308,7 @@ class MCPToolProvider(Base):
@property
def decrypted_server_url(self) -> str:
return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
return encrypter.decrypt_token(self.tenant_id, self.server_url)
@property
def masked_server_url(self) -> str:

View File

@ -146,7 +146,7 @@ class AccountService:
account.last_active_at = naive_utc_now()
db.session.commit()
return cast(Account, account)
return account
@staticmethod
def get_account_jwt_token(account: Account) -> str:
@ -191,7 +191,7 @@ class AccountService:
db.session.commit()
return cast(Account, account)
return account
@staticmethod
def update_account_password(account, password, new_password):
@ -1127,7 +1127,7 @@ class TenantService:
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
return tenant.custom_config_dict
@staticmethod
def is_owner(account: Account, tenant: Tenant) -> bool:

View File

@ -1,5 +1,5 @@
import uuid
from typing import cast
from typing import Optional
import pandas as pd
from flask_login import current_user
@ -40,7 +40,7 @@ class AppAnnotationService:
if not message:
raise NotFound("Message Not Exists.")
annotation = message.annotation
annotation: Optional[MessageAnnotation] = message.annotation
# save the message annotation
if annotation:
annotation.content = args["answer"]
@ -70,7 +70,7 @@ class AppAnnotationService:
app_id,
annotation_setting.collection_binding_id,
)
return cast(MessageAnnotation, annotation)
return annotation
@classmethod
def enable_app_annotation(cls, args: dict, app_id: str) -> dict:

View File

@ -1,7 +1,6 @@
import time
import uuid
from os import getenv
from typing import cast
import pytest
@ -13,7 +12,6 @@ from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.entities import CodeNodeData
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
from models.workflow import WorkflowType
@ -238,8 +236,6 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node._node_data.outputs)
@ -334,8 +330,6 @@ def test_execute_code_output_object_list():
]
}
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node._node_data.outputs)