rm type ignore (#25715)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
Asuka Minato 2025-10-21 12:26:58 +09:00 committed by GitHub
parent c11cdf7468
commit 32c715c4d0
78 changed files with 229 additions and 204 deletions


@@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
         default="postgresql",
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_DATABASE_URI(self) -> str:
         db_extras = (
@@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
         default=os.cpu_count() or 1,
     )

-    @computed_field  # type: ignore[misc]
+    @computed_field  # type: ignore[prop-decorator]
     @property
     def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
         # Parse DB_EXTRAS for 'options'
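Worth noting for reviewers: on pydantic's `@computed_field` stacked over `@property`, newer mypy releases report the decorator mismatch under the dedicated `prop-decorator` error code rather than `misc`, so the narrower code is what keeps these suppressions matching. A minimal sketch of the pattern, assuming pydantic v2 and a recent mypy:

from pydantic import BaseModel, computed_field

class Rect(BaseModel):
    width: int
    height: int

    @computed_field  # type: ignore[prop-decorator]
    @property
    def area(self) -> int:
        return self.width * self.height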


@@ -24,7 +24,7 @@ except ImportError:
     )
 else:
     warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
-    magic = None  # type: ignore
+    magic = None  # type: ignore[assignment]

 from pydantic import BaseModel
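Narrowing the bare ignore to `# type: ignore[assignment]` documents exactly what is suppressed: `magic` is bound to a module object by the import, so rebinding it to `None` on the fallback path is an assignment-type error and nothing else. A minimal sketch of the optional-import pattern, simplified from its original context:

try:
    import magic
except ImportError:
    # rebinding a module-typed name to None is an [assignment] error,
    # so only that code needs suppressing
    magic = None  # type: ignore[assignment]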


@@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             user=user,
             stream=streaming,
         )
-        # FIXME: Type hinting issue here, ignore it for now, will fix it later
-        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)  # type: ignore
+        return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)

     def _generate_worker(
         self,


@@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
-                response_chunk.update(sub_stream_response.to_ignore_detail_dict())  # ty: ignore [unresolved-attribute]
+                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
                 response_chunk.update(sub_stream_response.model_dump(mode="json"))
             yield response_chunk


@@ -98,7 +98,7 @@ class RateLimit:
         else:
             return RateLimitGenerator(
                 rate_limit=self,
-                generator=generator,  # ty: ignore [invalid-argument-type]
+                generator=generator,
                 request_id=request_id,
             )


@@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
         if isinstance(e, InvokeAuthorizationError):
             err = InvokeAuthorizationError("Incorrect API key provided")
         elif isinstance(e, InvokeError | ValueError):
-            err = e  # ty: ignore [invalid-assignment]
+            err = e
         else:
             description = getattr(e, "description", None)
             err = Exception(description if description is not None else str(e))


@@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
         if "/" not in key:
             key = str(ModelProviderID(key))
-        return self.configurations.get(key, default)  # type: ignore
+        return self.configurations.get(key, default)

 class ProviderModelBundle(BaseModel):


@@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
     else:
         # Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
        # FIXME: mypy does not support the type of spec.loader
         spec = importlib.util.spec_from_file_location(module_name, py_file_path)  # type: ignore[assignment]
     if not spec or not spec.loader:
         raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
     if use_lazy_loader:
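For reference, this code is built on the standard recipe from the linked docs page; `spec_from_file_location` returns `ModuleSpec | None`, which is why the guard on `spec` and `spec.loader` follows immediately. A standalone sketch:

import importlib.util

def load_from_path(module_name: str, py_file_path: str):
    # spec may be None if the path cannot be handled, hence the guard
    spec = importlib.util.spec_from_file_location(module_name, py_file_path)
    if not spec or not spec.loader:
        raise ImportError(f"Failed to load module {module_name} from {py_file_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module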


@@ -2,7 +2,7 @@ import logging
 import os
 from datetime import datetime, timedelta

-from langfuse import Langfuse  # type: ignore
+from langfuse import Langfuse
 from sqlalchemy.orm import sessionmaker

 from core.ops.base_trace_instance import BaseTraceInstance


@@ -180,7 +180,7 @@ class BasePluginClient:
         Make a request to the plugin daemon inner API and return the response as a model.
         """
         response = self._request(method, path, headers, data, params, files)
-        return type_(**response.json())  # type: ignore
+        return type_(**response.json())  # type: ignore[return-value]

     def _request_with_plugin_daemon_response(
         self,


@@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id
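The ignore (and its apologetic comment) were redundant because the checker already narrows here: raising on the falsy branch removes `None` from `tenant_id`'s type for the rest of the function. A minimal sketch of the narrowing, with hypothetical names:

def store_tenant(tenant_id: str | None) -> str:
    if not tenant_id:
        raise ValueError("User must have a tenant_id or current_tenant_id")
    # below the raise, tenant_id is narrowed from `str | None` to `str`
    narrowed: str = tenant_id
    return narrowed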


@@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
         tenant_id = extract_tenant_id(user)
         if not tenant_id:
             raise ValueError("User must have a tenant_id or current_tenant_id")
-        self._tenant_id = tenant_id  # type: ignore[assignment]  # We've already checked tenant_id is not None
+        self._tenant_id = tenant_id

         # Store app context
         self._app_id = app_id


@@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,
@@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
         try:
             repository_class = import_string(class_path)
-            return repository_class(  # type: ignore[no-any-return]
+            return repository_class(
                 session_factory=session_factory,
                 user=user,
                 app_id=app_id,
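`werkzeug.utils.import_string` is annotated to return `Any`, and calling an `Any` value yields `Any` again; only mypy's optional `no-any-return` check objects to returning that from a typed function, which is presumably why these suppressions are dead under the checker configured in this commit. A hedged sketch:

from typing import Any

from werkzeug.utils import import_string

def build_repository(class_path: str, **kwargs: Any) -> Any:
    repository_class = import_string(class_path)  # typed as Any
    return repository_class(**kwargs)  # calling Any yields Any; no ignore needed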


@@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         returns the tool that the provider can provide
         """
-        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)  # type: ignore
+        return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)

     @property
     def need_credentials(self) -> bool:


@@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
             content_text=tool_parameters.get("text"),  # type: ignore
             user=user_id,
             tenant_id=self.runtime.tenant_id,
-            voice=voice,  # type: ignore
+            voice=voice,
         )
         buffer = io.BytesIO()
         for chunk in tts:


@@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
         yield self.create_text_message(f"{timestamp}")

+    # TODO: this method's type is messy
     @staticmethod
     def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
         try:


@@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
             datetime_with_tz = input_timezone.localize(local_time)
             # timezone convert
             converted_datetime = datetime_with_tz.astimezone(output_timezone)
-            return converted_datetime.strftime(format=time_format)  # type: ignore
+            return converted_datetime.strftime(time_format)
         except Exception as e:
             raise ToolInvokeError(str(e))
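Besides dropping the ignore, this hunk switches `strftime` to a positional argument: typeshed declares the `format` parameter positional-only, so `strftime(format=...)` fails type checking even though CPython happens to accept the keyword at runtime. For example:

from datetime import datetime

# positional call matches the typeshed stub; strftime(format=fmt) does not
stamp = datetime(2025, 10, 21).strftime("%Y-%m-%d %H:%M:%S")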


@@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController):
         """
         pass

-    def get_tool(self, tool_name: str) -> MCPTool:  # type: ignore
+    def get_tool(self, tool_name: str) -> MCPTool:
         """
         return tool with given name
         """
@@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController):
             sse_read_timeout=self.sse_read_timeout,
         )

-    def get_tools(self) -> list[MCPTool]:  # type: ignore
+    def get_tools(self) -> list[MCPTool]:
         """
         get all tools
         """


@@ -26,7 +26,7 @@ class ToolLabelManager:
         labels = cls.filter_tool_labels(labels)

         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         else:
             raise ValueError("Unsupported tool type")
@@ -51,7 +51,7 @@ class ToolLabelManager:
         Get tool labels
         """
         if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
-            provider_id = controller.provider_id  # ty: ignore [unresolved-attribute]
+            provider_id = controller.provider_id
         elif isinstance(controller, BuiltinToolProviderController):
             return controller.tool_labels
         else:
@@ -85,7 +85,7 @@ class ToolLabelManager:
         provider_ids = []
         for controller in tool_providers:
             assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
-            provider_ids.append(controller.provider_id)  # ty: ignore [unresolved-attribute]
+            provider_ids.append(controller.provider_id)

         labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()


@@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     DatasetDocument.enabled == True,
                     DatasetDocument.archived == False,
                 )
-                document = db.session.scalar(dataset_document_stmt)  # type: ignore
+                document = db.session.scalar(dataset_document_stmt)
                 if dataset and document:
                     source = RetrievalSourceMetadata(
                         dataset_id=dataset.id,
                         dataset_name=dataset.name,
-                        document_id=document.id,  # type: ignore
-                        document_name=document.name,  # type: ignore
-                        data_source_type=document.data_source_type,  # type: ignore
+                        document_id=document.id,
+                        document_name=document.name,
+                        data_source_type=document.data_source_type,
                         segment_id=segment.id,
                         retriever_from=self.retriever_from,
                         score=record.score or 0.0,
-                        doc_metadata=document.doc_metadata,  # type: ignore
+                        doc_metadata=document.doc_metadata,
                     )

                     if self.retriever_from == "dev":


@@ -6,8 +6,8 @@ from typing import Any, cast
 from urllib.parse import unquote

 import chardet
-import cloudscraper  # type: ignore
-from readabilipy import simple_json_from_html_string  # type: ignore
+import cloudscraper
+from readabilipy import simple_json_from_html_string

 from core.helper import ssrf_proxy
 from core.rag.extractor import extract_processor
@@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
         response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
     elif response.status_code == 403:
         scraper = cloudscraper.create_scraper()
-        scraper.perform_request = ssrf_proxy.make_request  # type: ignore
-        response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))  # type: ignore
+        scraper.perform_request = ssrf_proxy.make_request
+        response = scraper.get(url, headers=headers, timeout=(120, 300))

     if response.status_code != 200:
         return f"URL returned status code {response.status_code}."


@@ -3,7 +3,7 @@ from functools import lru_cache
 from pathlib import Path
 from typing import Any

-import yaml  # type: ignore
+import yaml
 from yaml import YAMLError

 logger = logging.getLogger(__name__)


@@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController):
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)

         def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
-            return next(filter(lambda x: x.variable == variable_name, variables), None)  # type: ignore
+            return next(filter(lambda x: x.variable == variable_name, variables), None)

         user = db_provider.user


@@ -4,7 +4,7 @@ from .types import SegmentType

 class SegmentGroup(Segment):
     value_type: SegmentType = SegmentType.GROUP
-    value: list[Segment] = None  # type: ignore
+    value: list[Segment]

     @property
     def text(self):
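The replacement relies on pydantic's required-field semantics: an annotation with no default makes the field required, so the contradictory `= None` default (which the ignore was papering over) can simply be dropped. The same fix repeats across the Segment subclasses in the next file. A minimal sketch, assuming pydantic v2:

from pydantic import BaseModel, ValidationError

class StringSegment(BaseModel):
    value: str  # no default: the field is required

try:
    StringSegment()  # missing `value`
except ValidationError as exc:
    print(exc.error_count())  # 1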


@@ -19,7 +19,7 @@ class Segment(BaseModel):
     model_config = ConfigDict(frozen=True)

     value_type: SegmentType
-    value: Any = None
+    value: Any

     @field_validator("value_type")
     @classmethod
@@ -74,12 +74,12 @@ class NoneSegment(Segment):

 class StringSegment(Segment):
     value_type: SegmentType = SegmentType.STRING
-    value: str = None  # type: ignore
+    value: str

 class FloatSegment(Segment):
     value_type: SegmentType = SegmentType.FLOAT
-    value: float = None  # type: ignore
+    value: float
     # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
     # The following tests cannot pass.
     #
@@ -98,12 +98,12 @@ class FloatSegment(Segment):

 class IntegerSegment(Segment):
     value_type: SegmentType = SegmentType.INTEGER
-    value: int = None  # type: ignore
+    value: int

 class ObjectSegment(Segment):
     value_type: SegmentType = SegmentType.OBJECT
-    value: Mapping[str, Any] = None  # type: ignore
+    value: Mapping[str, Any]

     @property
     def text(self) -> str:
@@ -136,7 +136,7 @@ class ArraySegment(Segment):

 class FileSegment(Segment):
     value_type: SegmentType = SegmentType.FILE
-    value: File = None  # type: ignore
+    value: File

     @property
     def markdown(self) -> str:
@@ -153,17 +153,17 @@ class FileSegment(Segment):

 class BooleanSegment(Segment):
     value_type: SegmentType = SegmentType.BOOLEAN
-    value: bool = None  # type: ignore
+    value: bool

 class ArrayAnySegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_ANY
-    value: Sequence[Any] = None  # type: ignore
+    value: Sequence[Any]

 class ArrayStringSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_STRING
-    value: Sequence[str] = None  # type: ignore
+    value: Sequence[str]

     @property
     def text(self) -> str:
@@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):

 class ArrayNumberSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_NUMBER
-    value: Sequence[float | int] = None  # type: ignore
+    value: Sequence[float | int]

 class ArrayObjectSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_OBJECT
-    value: Sequence[Mapping[str, Any]] = None  # type: ignore
+    value: Sequence[Mapping[str, Any]]

 class ArrayFileSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_FILE
-    value: Sequence[File] = None  # type: ignore
+    value: Sequence[File]

     @property
     def markdown(self) -> str:
@@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):

 class ArrayBooleanSegment(ArraySegment):
     value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
-    value: Sequence[bool] = None  # type: ignore
+    value: Sequence[bool]

 def get_segment_discriminator(v: Any) -> SegmentType | None:


@@ -1,5 +1,6 @@
 import json
 from abc import ABC
+from builtins import type as type_
 from collections.abc import Sequence
 from enum import StrEnum
 from typing import Any, Union
@@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
             raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")

     @staticmethod
-    def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
+    def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
         """Unified array type validation"""
-        # FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
-        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)  # type: ignore
+        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)

     @staticmethod
     def _convert_number(value: str) -> float:
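The `builtins` alias is the interesting part of this fix: `isinstance` wants `type | tuple[type, ...]` as its second argument, but inside this class body the name `type` is presumably shadowed by a model field of that name, so the builtin is imported as `type_`. A hedged sketch of the shadowing problem, with a hypothetical `type` attribute:

from builtins import type as type_

class DefaultValue:
    type = "array[string]"  # hypothetical attribute shadowing builtins.type

    @staticmethod
    def _validate_array(value: object, element_type: type_ | tuple[type_, ...]) -> bool:
        # the alias keeps the annotation pointing at builtins.type
        return isinstance(value, list) and all(isinstance(x, element_type) for x in value)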


@@ -10,10 +10,10 @@ from typing import Any
 import chardet
 import docx
 import pandas as pd
-import pypandoc  # type: ignore
-import pypdfium2  # type: ignore
-import webvtt  # type: ignore
-import yaml  # type: ignore
+import pypandoc
+import pypdfium2
+import webvtt
+import yaml
 from docx.document import Document
 from docx.oxml.table import CT_Tbl
 from docx.oxml.text.paragraph import CT_P


@@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node):
     def version(cls):
         return "1"

-    def _run(self) -> NodeRunResult:  # type: ignore
+    def _run(self) -> NodeRunResult:
         # extract variables
         variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
         if not isinstance(variable, StringSegment):
@@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node):
                 metadata_condition = MetadataCondition(
                     logical_operator=node_data.metadata_filtering_conditions.logical_operator
                     if node_data.metadata_filtering_conditions
-                    else "or",  # type: ignore
+                    else "or",
                     conditions=conditions,
                 )
             elif node_data.metadata_filtering_mode == "manual":
@@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node):
                         expected_value = self.graph_runtime_state.variable_pool.convert_template(
                             expected_value
                         ).value[0]
-                        if expected_value.value_type in {"number", "integer", "float"}:  # type: ignore
-                            expected_value = expected_value.value  # type: ignore
-                        elif expected_value.value_type == "string":  # type: ignore
-                            expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()  # type: ignore
+                        if expected_value.value_type in {"number", "integer", "float"}:
+                            expected_value = expected_value.value
+                        elif expected_value.value_type == "string":
+                            expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
                         else:
                             raise ValueError("Invalid expected metadata value type")
                         conditions.append(
@@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node):
             if (
                 node_data.metadata_filtering_conditions
                 and node_data.metadata_filtering_conditions.logical_operator == "and"
-            ):  # type: ignore
+            ):
                 document_query = document_query.where(and_(*filters))
             else:
                 document_query = document_query.where(or_(*filters))


@@ -260,7 +260,7 @@ class VariablePool(BaseModel):
             # This ensures that we can keep the id of the system variables intact.
             if self._has(selector):
                 continue
-            self.add(selector, value)  # type: ignore
+            self.add(selector, value)

     @classmethod
     def empty(cls) -> "VariablePool":


@@ -7,7 +7,7 @@ def is_enabled() -> bool:

 def init_app(app: DifyApp):
-    from flask_compress import Compress  # type: ignore
+    from flask_compress import Compress

     compress = Compress()
     compress.init_app(app)


@@ -1,6 +1,6 @@
 import json

-import flask_login  # type: ignore
+import flask_login
 from flask import Response, request
 from flask_login import user_loaded_from_request, user_logged_in
 from werkzeug.exceptions import NotFound, Unauthorized


@@ -2,7 +2,7 @@ from dify_app import DifyApp

 def init_app(app: DifyApp):
-    import flask_migrate  # type: ignore
+    import flask_migrate

     from extensions.ext_database import db


@@ -103,7 +103,7 @@ def init_app(app: DifyApp):
     def shutdown_tracer():
         provider = trace.get_tracer_provider()
         if hasattr(provider, "force_flush"):
-            provider.force_flush()  # ty: ignore [call-non-callable]
+            provider.force_flush()

 class ExceptionLoggingHandler(logging.Handler):
     """Custom logging handler that creates spans for logging.exception() calls"""


@@ -6,4 +6,4 @@ def init_app(app: DifyApp):
     if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
         from werkzeug.middleware.proxy_fix import ProxyFix

-        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore
+        app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1)  # type: ignore[method-assign]


@@ -5,7 +5,7 @@ from dify_app import DifyApp

 def init_app(app: DifyApp):
     if dify_config.SENTRY_DSN:
         import sentry_sdk
-        from langfuse import parse_error  # type: ignore
+        from langfuse import parse_error
         from sentry_sdk.integrations.celery import CeleryIntegration
         from sentry_sdk.integrations.flask import FlaskIntegration
         from werkzeug.exceptions import HTTPException


@@ -1,7 +1,7 @@
 import posixpath
 from collections.abc import Generator

-import oss2 as aliyun_s3  # type: ignore
+import oss2 as aliyun_s3

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -2,9 +2,9 @@ import base64
 import hashlib
 from collections.abc import Generator

-from baidubce.auth.bce_credentials import BceCredentials  # type: ignore
-from baidubce.bce_client_configuration import BceClientConfiguration  # type: ignore
-from baidubce.services.bos.bos_client import BosClient  # type: ignore
+from baidubce.auth.bce_credentials import BceCredentials
+from baidubce.bce_client_configuration import BceClientConfiguration
+from baidubce.services.bos.bos_client import BosClient

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -11,7 +11,7 @@ from collections.abc import Generator
 from io import BytesIO
 from pathlib import Path

-import clickzetta  # type: ignore[import]
+import clickzetta
 from pydantic import BaseModel, model_validator

 from extensions.storage.base_storage import BaseStorage


@@ -34,7 +34,7 @@ class VolumePermissionManager:
         # Support two initialization methods: connection object or configuration dictionary
         if isinstance(connection_or_config, dict):
             # Create connection from configuration dictionary
-            import clickzetta  # type: ignore[import-untyped]
+            import clickzetta

             config = connection_or_config
             self._connection = clickzetta.connect(


@@ -3,7 +3,7 @@ import io
 import json
 from collections.abc import Generator

-from google.cloud import storage as google_cloud_storage  # type: ignore
+from google.cloud import storage as google_cloud_storage

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -1,6 +1,6 @@
 from collections.abc import Generator

-from obs import ObsClient  # type: ignore
+from obs import ObsClient

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -1,7 +1,7 @@
 from collections.abc import Generator

-import boto3  # type: ignore
-from botocore.exceptions import ClientError  # type: ignore
+import boto3
+from botocore.exceptions import ClientError

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -1,6 +1,6 @@
 from collections.abc import Generator

-from qcloud_cos import CosConfig, CosS3Client  # type: ignore
+from qcloud_cos import CosConfig, CosS3Client

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -1,6 +1,6 @@
 from collections.abc import Generator

-import tos  # type: ignore
+import tos

 from configs import dify_config
 from extensions.storage.base_storage import BaseStorage


@@ -146,6 +146,6 @@ class ExternalApi(Api):
         kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
         # manual separate call on construction and init_app to ensure configs in kwargs effective
-        super().__init__(app=None, *args, **kwargs)  # type: ignore
+        super().__init__(app=None, *args, **kwargs)
         self.init_app(app, **kwargs)
         register_external_error_handlers(self)


@@ -23,7 +23,7 @@ from hashlib import sha1

 import Crypto.Hash.SHA1
 import Crypto.Util.number
-import gmpy2  # type: ignore
+import gmpy2
 from Crypto import Random
 from Crypto.Signature.pss import MGF1
 from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
@@ -136,7 +136,7 @@ class PKCS1OAepCipher:
         # Step 3a (OS2IP)
         em_int = bytes_to_long(em)
         # Step 3b (RSAEP)
-        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)  # ty: ignore [unresolved-attribute]
+        m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
         # Step 3c (I2OSP)
         c = long_to_bytes(m_int, k)
         return c
@@ -169,7 +169,7 @@ class PKCS1OAepCipher:
         ct_int = bytes_to_long(ciphertext)
         # Step 2b (RSADP)
         # m_int = self._key._decrypt(ct_int)
-        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)  # ty: ignore [unresolved-attribute]
+        m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
         # Complete step 2c (I2OSP)
         em = long_to_bytes(m_int, k)
         # Step 3a
@@ -191,12 +191,12 @@ class PKCS1OAepCipher:
         # Step 3g
         one_pos = hLen + db[hLen:].find(b"\x01")
         lHash1 = db[:hLen]
-        invalid = bord(y) | int(one_pos < hLen)  # type: ignore
+        invalid = bord(y) | int(one_pos < hLen)  # type: ignore[arg-type]
         hash_compare = strxor(lHash1, lHash)
         for x in hash_compare:
-            invalid |= bord(x)  # type: ignore
+            invalid |= bord(x)  # type: ignore[arg-type]
         for x in db[hLen:one_pos]:
-            invalid |= bord(x)  # type: ignore
+            invalid |= bord(x)  # type: ignore[arg-type]
         if invalid != 0:
             raise ValueError("Incorrect decryption.")
         # Step 4


@@ -3,7 +3,7 @@ from functools import wraps
 from typing import Any

 from flask import current_app, g, has_request_context, request
-from flask_login.config import EXEMPT_METHODS  # type: ignore
+from flask_login.config import EXEMPT_METHODS
 from werkzeug.local import LocalProxy

 from configs import dify_config
@@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
         if "_login_user" not in g:
             current_app.login_manager._load_user()  # type: ignore
-        return g._login_user  # type: ignore
+        return g._login_user
     return None


@@ -1,8 +1,8 @@
 import logging

-import sendgrid  # type: ignore
+import sendgrid
 from python_http_client.exceptions import ForbiddenError, UnauthorizedError
-from sendgrid.helpers.mail import Content, Email, Mail, To  # type: ignore
+from sendgrid.helpers.mail import Content, Email, Mail, To

 logger = logging.getLogger(__name__)


@@ -5,7 +5,7 @@ from datetime import datetime
 from typing import Any, Optional

 import sqlalchemy as sa
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import DateTime, String, func, select
 from sqlalchemy.orm import Mapped, Session, mapped_column
 from typing_extensions import deprecated


@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast

 import sqlalchemy as sa
 from flask import request
-from flask_login import UserMixin  # type: ignore[import-untyped]
+from flask_login import UserMixin
 from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
 from sqlalchemy.orm import Mapped, Session, mapped_column


@@ -16,7 +16,25 @@
         "opentelemetry.instrumentation.requests",
         "opentelemetry.instrumentation.sqlalchemy",
         "opentelemetry.instrumentation.redis",
-        "opentelemetry.instrumentation.httpx"
+        "langfuse",
+        "cloudscraper",
+        "readabilipy",
+        "pypandoc",
+        "pypdfium2",
+        "webvtt",
+        "flask_compress",
+        "oss2",
+        "baidubce.auth.bce_credentials",
+        "baidubce.bce_client_configuration",
+        "baidubce.services.bos.bos_client",
+        "clickzetta",
+        "google.cloud",
+        "obs",
+        "qcloud_cos",
+        "tos",
+        "gmpy2",
+        "sendgrid",
+        "sendgrid.helpers.mail"
     ],
     "reportUnknownMemberType": "hint",
     "reportUnknownParameterType": "hint",
@@ -28,7 +46,7 @@
     "reportUnnecessaryComparison": "hint",
     "reportUnnecessaryIsInstance": "hint",
     "reportUntypedFunctionDecorator": "hint",
+    "reportUnnecessaryTypeIgnoreComment": "hint",
     "reportAttributeAccessIssue": "hint",
     "pythonVersion": "3.11",
     "pythonPlatform": "All"


@@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(
                 f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
@@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
         try:
             repository_class = import_string(class_path)
-            return repository_class(session_maker=session_maker)  # type: ignore[no-any-return]
+            return repository_class(session_maker=session_maker)
         except (ImportError, Exception) as e:
             raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e


@@ -7,7 +7,7 @@ from enum import StrEnum
 from urllib.parse import urlparse
 from uuid import uuid4

-import yaml  # type: ignore
+import yaml
 from Crypto.Cipher import AES
 from Crypto.Util.Padding import pad, unpad
 from packaging import version
@@ -563,7 +563,7 @@ class AppDslService:
         else:
             cls._append_model_config_export_data(export_data, app_model)

-        return yaml.dump(export_data, allow_unicode=True)  # type: ignore
+        return yaml.dump(export_data, allow_unicode=True)

     @classmethod
     def _append_workflow_export_data(


@@ -241,9 +241,9 @@ class DatasetService:
         dataset.created_by = account.id
         dataset.updated_by = account.id
         dataset.tenant_id = tenant_id
-        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None  # type: ignore
-        dataset.embedding_model = embedding_model.model if embedding_model else None  # type: ignore
-        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None  # type: ignore
+        dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
+        dataset.embedding_model = embedding_model.model if embedding_model else None
+        dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
         dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
         dataset.provider = provider
         db.session.add(dataset)
@@ -1416,6 +1416,8 @@ class DocumentService:
         # check document limit
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
+        assert knowledge_config.data_source
+        assert knowledge_config.data_source.info_list.file_info_list

         features = FeatureService.get_features(current_user.current_tenant_id)
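The two new `assert` statements are doing type work, not just sanity checking: `data_source` and `file_info_list` are Optional on the config model, and an assert narrows them to their non-`None` types for the rest of the function, which is what lets the `# type: ignore` comments below disappear. A minimal sketch of the idiom, with simplified, hypothetical field names:

from dataclasses import dataclass

@dataclass
class FileInfoList:
    file_ids: list[str]

@dataclass
class DataSource:
    file_info_list: FileInfoList | None = None

@dataclass
class KnowledgeConfig:
    data_source: DataSource | None = None

def count_files(config: KnowledgeConfig) -> int:
    assert config.data_source  # narrows DataSource | None to DataSource
    assert config.data_source.file_info_list  # narrows FileInfoList | None
    return len(config.data_source.file_info_list.file_ids)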
@@ -1424,15 +1426,16 @@ class DocumentService:
             count = 0
             if knowledge_config.data_source:
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                     count = len(upload_file_list)
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:  # type: ignore
+                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list or []
+                    for notion_info in notion_info_list:
                         count = count + len(notion_info.pages)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)  # type: ignore
+                    assert website_info
+                    count = len(website_info.urls)

             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if features.billing.subscription.plan == "sandbox" and count > 1:
@@ -1444,7 +1447,7 @@ class DocumentService:
         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
@@ -1481,7 +1484,7 @@ class DocumentService:
                 knowledge_config.retrieval_model.model_dump()
                 if knowledge_config.retrieval_model
                 else default_retrieval_model
-            )  # type: ignore
+            )

         documents = []
         if knowledge_config.original_document_id:
@@ -1523,11 +1526,12 @@ class DocumentService:
             db.session.flush()

         lock_name = f"add_document_lock_dataset_id_{dataset.id}"
         with redis_client.lock(lock_name, timeout=600):
+            assert dataset_process_rule
             position = DocumentService.get_documents_position(dataset.id)
             document_ids = []
             duplicate_document_ids = []
-            if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
                         db.session.query(UploadFile)
@@ -1540,7 +1544,7 @@ class DocumentService:
                         raise FileNotExistsError()

                     file_name = file.name
-                    data_source_info = {
+                    data_source_info: dict[str, str | bool] = {
                         "upload_file_id": file_id,
                     }
                     # check duplicate
@@ -1557,7 +1561,7 @@ class DocumentService:
                             .first()
                         )
                         if document:
-                            document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
+                            document.dataset_process_rule_id = dataset_process_rule.id
                             document.updated_at = naive_utc_now()
                             document.created_from = created_from
                             document.doc_form = knowledge_config.doc_form
@@ -1571,8 +1575,8 @@ class DocumentService:
                             continue
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,  # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,
@@ -1587,7 +1591,7 @@ class DocumentService:
                     document_ids.append(document.id)
                     documents.append(document)
                     position += 1
-            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
                 if not notion_info_list:
                     raise ValueError("No notion info list found.")
@@ -1616,15 +1620,15 @@ class DocumentService:
                             "credential_id": notion_info.credential_id,
                             "notion_workspace_id": workspace_id,
                             "notion_page_id": page.page_id,
-                            "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
+                            "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                             "type": page.type,
                         }
                         # Truncate page name to 255 characters to prevent DB field length errors
                         truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,  # type: ignore
-                            knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+                            dataset_process_rule.id,
+                            knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
                             data_source_info,
@@ -1644,8 +1648,8 @@ class DocumentService:
                 # delete not selected documents
                 if len(exist_document) > 0:
                     clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
-            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
-                website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
                 if not website_info:
                     raise ValueError("No website info list found.")
                 urls = website_info.urls
@@ -1663,8 +1667,8 @@ class DocumentService:
                         document_name = url
                     document = DocumentService.build_document(
                         dataset,
-                        dataset_process_rule.id,  # type: ignore
-                        knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+                        dataset_process_rule.id,
+                        knowledge_config.data_source.info_list.data_source_type,
                         knowledge_config.doc_form,
                         knowledge_config.doc_language,
                         data_source_info,
@@ -2071,7 +2075,7 @@ class DocumentService:
         # update document data source
         if document_data.data_source:
             file_name = ""
-            data_source_info = {}
+            data_source_info: dict[str, str | bool] = {}
             if document_data.data_source.info_list.data_source_type == "upload_file":
                 if not document_data.data_source.info_list.file_info_list:
                     raise ValueError("No file info list found.")
@@ -2128,7 +2132,7 @@ class DocumentService:
                     "url": url,
                     "provider": website_info.provider,
                     "job_id": website_info.job_id,
-                    "only_main_content": website_info.only_main_content,  # type: ignore
+                    "only_main_content": website_info.only_main_content,
                     "mode": "crawl",
                 }
                 document.data_source_type = document_data.data_source.info_list.data_source_type
@@ -2154,7 +2158,7 @@ class DocumentService:
             db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
                 {DocumentSegment.status: "re_segment"}
-            )  # type: ignore
+            )
             db.session.commit()
             # trigger async task
             document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -2164,25 +2168,26 @@ class DocumentService:
     def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
         assert isinstance(current_user, Account)
         assert current_user.current_tenant_id is not None
+        assert knowledge_config.data_source

         features = FeatureService.get_features(current_user.current_tenant_id)

         if features.billing.enabled:
             count = 0
-            if knowledge_config.data_source.info_list.data_source_type == "upload_file":  # type: ignore
+            if knowledge_config.data_source.info_list.data_source_type == "upload_file":
                 upload_file_list = (
-                    knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
-                    if knowledge_config.data_source.info_list.file_info_list  # type: ignore
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
                     else []
                 )
                 count = len(upload_file_list)
-            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":  # type: ignore
-                notion_info_list = knowledge_config.data_source.info_list.notion_info_list  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                notion_info_list = knowledge_config.data_source.info_list.notion_info_list
                 if notion_info_list:
                     for notion_info in notion_info_list:
                         count = count + len(notion_info.pages)
-            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":  # type: ignore
-                website_info = knowledge_config.data_source.info_list.website_info_list  # type: ignore
+            elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                website_info = knowledge_config.data_source.info_list.website_info_list
                 if website_info:
                     count = len(website_info.urls)

             if features.billing.subscription.plan == "sandbox" and count > 1:
@@ -2196,9 +2201,11 @@ class DocumentService:
         dataset_collection_binding_id = None
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
+            assert knowledge_config.embedding_model_provider
+            assert knowledge_config.embedding_model
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider,  # type: ignore
-                knowledge_config.embedding_model,  # type: ignore
+                knowledge_config.embedding_model_provider,
+                knowledge_config.embedding_model,
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
@@ -2215,7 +2222,7 @@ class DocumentService:
         dataset = Dataset(
             tenant_id=tenant_id,
             name="",
-            data_source_type=knowledge_config.data_source.info_list.data_source_type,  # type: ignore
+            data_source_type=knowledge_config.data_source.info_list.data_source_type,
             indexing_technique=knowledge_config.indexing_technique,
             created_by=account.id,
             embedding_model=knowledge_config.embedding_model,
@@ -2224,7 +2231,7 @@ class DocumentService:
             retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
         )

-        db.session.add(dataset)  # type: ignore
+        db.session.add(dataset)
         db.session.flush()

         documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)


@@ -88,7 +88,7 @@ class HitTestingService:
         db.session.add(dataset_query)
         db.session.commit()

-        return cls.compact_retrieve_response(query, all_documents)  # type: ignore
+        return cls.compact_retrieve_response(query, all_documents)

     @classmethod
     def external_retrieve(


@@ -1,4 +1,4 @@
-import boto3  # type: ignore
+import boto3

 from configs import dify_config


@@ -89,7 +89,7 @@ class MetadataService:
                 document.doc_metadata = doc_metadata
                 db.session.add(document)
             db.session.commit()
-            return metadata  # type: ignore
+            return metadata
         except Exception:
             logger.exception("Update metadata name failed")
         finally:


@@ -137,7 +137,7 @@ class ModelProviderService:
         :return:
         """
         provider_configuration = self._get_provider_configuration(tenant_id, provider)
-        return provider_configuration.get_provider_credential(credential_id=credential_id)  # type: ignore
+        return provider_configuration.get_provider_credential(credential_id=credential_id)

     def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
         """
@@ -225,7 +225,7 @@ class ModelProviderService:
         :return:
         """
         provider_configuration = self._get_provider_configuration(tenant_id, provider)
-        return provider_configuration.get_custom_model_credential(  # type: ignore
+        return provider_configuration.get_custom_model_credential(
             model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
         )


@@ -146,7 +146,7 @@ class PluginMigration:
                 futures.append(
                     thread_pool.submit(
                         process_tenant,
-                        current_app._get_current_object(),  # type: ignore[attr-defined]
+                        current_app._get_current_object(),  # type: ignore
                         tenant_id,
                     )
                 )
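A minimal sketch of why an ignore survives in the hunk above: `_get_current_object()` unwraps the thread-local `current_app` proxy to the concrete Flask app, which is required when crossing a thread boundary, but it is a private proxy method, so type checkers still flag it. The `process_tenant` body here is a hypothetical stand-in.

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app

def process_tenant(flask_app: Flask, tenant_id: str) -> None:
    # worker threads need the concrete app object, not the proxy,
    # so they can push their own application context
    with flask_app.app_context():
        print(f"processing {tenant_id} in {flask_app.name}")

app = Flask(__name__)

with app.app_context():
    with ThreadPoolExecutor() as pool:
        # unwrap the proxy before handing it to another thread
        pool.submit(process_tenant, current_app._get_current_object(), "tenant-1").result()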


@@ -544,8 +544,8 @@ class BuiltinToolManageService:
             try:
                 # handle include, exclude
                 if is_filtered(
-                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
-                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
+                    include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
+                    exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
                     data=provider_controller,
                     name_func=lambda x: x.entity.identity.name,
                 ):


@@ -308,7 +308,7 @@ class MCPToolManageService:
         provider_controller = MCPToolProviderController.from_db(mcp_provider)
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=mcp_provider.tenant_id,
-            config=list(provider_controller.get_credentials_schema()),  # ty: ignore [invalid-argument-type]
+            config=list(provider_controller.get_credentials_schema()),
             provider_config_cache=NoOpProviderCredentialCache(),
         )
         credentials = tool_configuration.encrypt(credentials)


@@ -102,7 +102,7 @@ def batch_create_segment_to_index_task(
         for segment, tokens in zip(content, tokens_list):
             content = segment["content"]
             doc_id = str(uuid.uuid4())
-            segment_hash = helper.generate_text_hash(content)  # type: ignore
+            segment_hash = helper.generate_text_hash(content)
             max_position = (
                 db.session.query(func.max(DocumentSegment.position))
                 .where(DocumentSegment.document_id == dataset_document.id)


@@ -5,11 +5,11 @@ from unittest.mock import MagicMock
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from pymochow import MochowClient  # type: ignore
-from pymochow.model.database import Database  # type: ignore
-from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState  # type: ignore
-from pymochow.model.schema import HNSWParams, VectorIndex  # type: ignore
-from pymochow.model.table import Table  # type: ignore
+from pymochow import MochowClient
+from pymochow.model.database import Database
+from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
+from pymochow.model.schema import HNSWParams, VectorIndex
+from pymochow.model.table import Table

 class AttrDict(UserDict):


@@ -3,15 +3,15 @@ from typing import Any, Union
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from tcvectordb import RPCVectorDBClient  # type: ignore
+from tcvectordb import RPCVectorDBClient
 from tcvectordb.model import enum
 from tcvectordb.model.collection import FilterIndexConfig
-from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank  # type: ignore
-from tcvectordb.model.enum import ReadConsistency  # type: ignore
-from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex  # type: ignore
+from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
+from tcvectordb.model.enum import ReadConsistency
+from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
 from tcvectordb.rpc.model.collection import RPCCollection
 from tcvectordb.rpc.model.database import RPCDatabase
-from xinference_client.types import Embedding  # type: ignore
+from xinference_client.types import Embedding

 class MockTcvectordbClass:


@@ -4,7 +4,7 @@ from unittest.mock import MagicMock
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from volcengine.viking_db import (  # type: ignore
+from volcengine.viking_db import (
     Collection,
     Data,
     DistanceType,


@@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
         """Test with None input"""
         # The method signature expects Union[dict, list, Segment], but implementation handles None
         # We'll test the actual behavior by passing an empty dict instead
-        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)  # type: ignore
+        result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
         assert result == []

     def test_fetch_files_from_variable_value_with_empty_dict(self):


@@ -235,7 +235,7 @@ class TestIndividualHandlers:
         # Type assertion needed due to union type
         text_content = result.content[0]
         assert hasattr(text_content, "text")
-        assert text_content.text == "test answer"  # type: ignore[attr-defined]
+        assert text_content.text == "test answer"

     def test_handle_call_tool_no_end_user(self):
         """Test call tool handler without end user"""


@@ -212,7 +212,7 @@ class TestValidateResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select",  # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],
@@ -400,7 +400,7 @@ class TestTransformResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select",  # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],
@@ -414,7 +414,7 @@ class TestTransformResult:
             parameters=[
                 ParameterConfig(
                     name="status",
-                    type="select",  # type: ignore
+                    type="select",
                     description="Status",
                     required=True,
                     options=["active", "inactive"],


@@ -248,4 +248,4 @@ def test_constructor_with_extra_key():
     # Test that SystemVariable should forbid extra keys
     with pytest.raises(ValidationError):
         # This should fail because there is an unexpected key.
-        SystemVariable(invalid_key=1)  # type: ignore
+        SystemVariable(invalid_key=1)
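A minimal sketch of why the test above raises without any ignore: with Pydantic's extra-key handling set to forbid, an unexpected keyword fails validation at runtime. The exact configuration of the real SystemVariable model is an assumption here.

from pydantic import BaseModel, ConfigDict, ValidationError

class SystemVariable(BaseModel):
    model_config = ConfigDict(extra="forbid")  # assumption: unknown keys are rejected
    user_id: str | None = None

try:
    SystemVariable(invalid_key=1)
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # prints "extra_forbidden"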


@@ -14,36 +14,36 @@ def _create_api_app():
     api = ExternalApi(bp)

     @api.route("/bad-request")
-    class Bad(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Bad(Resource):
+        def get(self):
             raise BadRequest("invalid input")

     @api.route("/unauth")
-    class Unauth(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Unauth(Resource):
+        def get(self):
             raise Unauthorized("auth required")

     @api.route("/value-error")
-    class ValErr(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class ValErr(Resource):
+        def get(self):
             raise ValueError("boom")

     @api.route("/quota")
-    class Quota(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Quota(Resource):
+        def get(self):
             raise AppInvokeQuotaExceededError("quota exceeded")

     @api.route("/general")
-    class Gen(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class Gen(Resource):
+        def get(self):
             raise RuntimeError("oops")

     # Note: We avoid altering default_mediatype to keep normal error paths
     # Special 400 message rewrite
     @api.route("/json-empty")
-    class JsonEmpty(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class JsonEmpty(Resource):
+        def get(self):
             e = BadRequest()
             # Force the specific message the handler rewrites
             e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
@@ -51,11 +51,11 @@ def _create_api_app():
     # 400 mapping payload path
     @api.route("/param-errors")
-    class ParamErrors(Resource):  # type: ignore
-        def get(self):  # type: ignore
+    class ParamErrors(Resource):
+        def get(self):
             e = BadRequest()
             # Coerce a mapping description to trigger param error shaping
-            e.description = {"field": "is required"}  # type: ignore[assignment]
+            e.description = {"field": "is required"}
             raise e

     app.register_blueprint(bp, url_prefix="/api")
@@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
     orig_exc_info = ext.sys.exc_info
     try:
-        ext.sys.exc_info = lambda: (None, None, None)  # type: ignore[assignment]
+        ext.sys.exc_info = lambda: (None, None, None)
         app = _create_api_app()
         client = app.test_client()


@@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user:
             # without preserve_flask_contexts
             result["user_accessible"] = current_user.is_authenticated
         except Exception as e:
-            result["error"] = str(e)  # type: ignore
+            result["error"] = str(e)

     # Run the function in a separate thread
     thread = threading.Thread(target=check_user_in_thread)
@@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask,
             else:
                 result["user_accessible"] = False
         except Exception as e:
-            result["error"] = str(e)  # type: ignore
+            result["error"] = str(e)

     # Run the function in a separate thread
     thread = threading.Thread(target=check_user_in_thread_with_manager)


@@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented():
         oauth.get_raw_user_info("token")

     with pytest.raises(NotImplementedError):
-        oauth._transform_user_info({})  # type: ignore[name-defined]
+        oauth._transform_user_info({})


@@ -3,8 +3,8 @@ from unittest.mock import MagicMock
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from qcloud_cos import CosS3Client  # type: ignore
-from qcloud_cos.streambody import StreamBody  # type: ignore
+from qcloud_cos import CosS3Client
+from qcloud_cos.streambody import StreamBody

 from tests.unit_tests.oss.__mock.base import (
     get_example_bucket,


@@ -4,8 +4,8 @@ from unittest.mock import MagicMock
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
-from tos import TosClientV2  # type: ignore
-from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput  # type: ignore
+from tos import TosClientV2
+from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput

 from tests.unit_tests.oss.__mock.base import (
     get_example_bucket,


@@ -1,7 +1,7 @@
 from unittest.mock import patch

 import pytest
-from qcloud_cos import CosConfig  # type: ignore
+from qcloud_cos import CosConfig

 from extensions.storage.tencent_cos_storage import TencentCosStorage
 from tests.unit_tests.oss.__mock.base import (


@@ -1,7 +1,7 @@
 from unittest.mock import patch

 import pytest
-from tos import TosClientV2  # type: ignore
+from tos import TosClientV2

 from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
 from tests.unit_tests.oss.__mock.base import (


@@ -125,13 +125,13 @@ class TestApiKeyAuthService:
         mock_session.commit = Mock()

         args_copy = self.mock_args.copy()
-        original_key = args_copy["credentials"]["config"]["api_key"]  # type: ignore
+        original_key = args_copy["credentials"]["config"]["api_key"]

         ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)

         # Verify original key is replaced with encrypted key
-        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key  # type: ignore
-        assert args_copy["credentials"]["config"]["api_key"] != original_key  # type: ignore
+        assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
+        assert args_copy["credentials"]["config"]["api_key"] != original_key

         # Verify encryption function is called correctly
         mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
@@ -268,7 +268,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_empty_credentials(self):
         """Test API key auth args validation - empty credentials"""
         args = self.mock_args.copy()
-        args["credentials"] = None  # type: ignore
+        args["credentials"] = None

         with pytest.raises(ValueError, match="credentials is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)
@@ -284,7 +284,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_missing_auth_type(self):
         """Test API key auth args validation - missing auth_type"""
         args = self.mock_args.copy()
-        del args["credentials"]["auth_type"]  # type: ignore
+        del args["credentials"]["auth_type"]

         with pytest.raises(ValueError, match="auth_type is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)
@@ -292,7 +292,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_empty_auth_type(self):
         """Test API key auth args validation - empty auth_type"""
         args = self.mock_args.copy()
-        args["credentials"]["auth_type"] = ""  # type: ignore
+        args["credentials"]["auth_type"] = ""

         with pytest.raises(ValueError, match="auth_type is required"):
             ApiKeyAuthService.validate_api_key_auth_args(args)
@@ -380,7 +380,7 @@ class TestApiKeyAuthService:
     def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
         """Test API key auth args validation - dict credentials with list auth_type"""
         args = self.mock_args.copy()
-        args["credentials"]["auth_type"] = ["api_key"]  # type: ignore  # list instead of string
+        args["credentials"]["auth_type"] = ["api_key"]
         # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
         # So this should not raise exception, this test should pass
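A minimal sketch of the validation these tests pin down, assuming the simple truthiness checks the in-test comments describe (the real service's implementation may differ):

def validate_api_key_auth_args(args: dict) -> None:
    # empty/None credentials and missing/empty auth_type are rejected;
    # a truthy list such as ["api_key"] slips through, as the last test notes
    if not args.get("credentials"):
        raise ValueError("credentials is required")
    if not args["credentials"].get("auth_type"):
        raise ValueError("auth_type is required")

validate_api_key_auth_args({"credentials": {"auth_type": ["api_key"]}})  # passes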


@@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter:
         encrypter = SystemOAuthEncrypter("test_secret")

         with pytest.raises(Exception):  # noqa: B017
-            encrypter.encrypt_oauth_params(None)  # type: ignore
+            encrypter.encrypt_oauth_params(None)

         with pytest.raises(Exception):  # noqa: B017
-            encrypter.encrypt_oauth_params("not_a_dict")  # type: ignore
+            encrypter.encrypt_oauth_params("not_a_dict")

     def test_decrypt_oauth_params_basic(self):
         """Test basic OAuth parameters decryption"""
@@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter:
         encrypter = SystemOAuthEncrypter("test_secret")

         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(123)  # type: ignore
+            encrypter.decrypt_oauth_params(123)
         assert "encrypted_data must be a string" in str(exc_info.value)

         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(None)  # type: ignore
+            encrypter.decrypt_oauth_params(None)
         assert "encrypted_data must be a string" in str(exc_info.value)
@@ -461,14 +461,14 @@ class TestConvenienceFunctions:
         """Test convenience functions with error conditions"""
         # Test encryption with invalid input
         with pytest.raises(Exception):  # noqa: B017
-            encrypt_system_oauth_params(None)  # type: ignore
+            encrypt_system_oauth_params(None)

         # Test decryption with invalid input
         with pytest.raises(ValueError):
             decrypt_system_oauth_params("")

         with pytest.raises(ValueError):
-            decrypt_system_oauth_params(None)  # type: ignore
+            decrypt_system_oauth_params(None)

 class TestErrorHandling:
@@ -501,7 +501,7 @@ class TestErrorHandling:
         # Test non-string error
         with pytest.raises(ValueError) as exc_info:
-            encrypter.decrypt_oauth_params(123)  # type: ignore
+            encrypter.decrypt_oauth_params(123)
         assert "encrypted_data must be a string" in str(exc_info.value)

         # Test invalid format error