mirror of https://github.com/langgenius/dify.git
rm type ignore (#25715)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
parent
c11cdf7468
commit
32c715c4d0
|
|
@ -145,7 +145,7 @@ class DatabaseConfig(BaseSettings):
|
||||||
default="postgresql",
|
default="postgresql",
|
||||||
)
|
)
|
||||||
|
|
||||||
@computed_field # type: ignore[misc]
|
@computed_field # type: ignore[prop-decorator]
|
||||||
@property
|
@property
|
||||||
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
def SQLALCHEMY_DATABASE_URI(self) -> str:
|
||||||
db_extras = (
|
db_extras = (
|
||||||
|
|
@ -198,7 +198,7 @@ class DatabaseConfig(BaseSettings):
|
||||||
default=os.cpu_count() or 1,
|
default=os.cpu_count() or 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
@computed_field # type: ignore[misc]
|
@computed_field # type: ignore[prop-decorator]
|
||||||
@property
|
@property
|
||||||
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
|
||||||
# Parse DB_EXTRAS for 'options'
|
# Parse DB_EXTRAS for 'options'
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ except ImportError:
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
warnings.warn("To use python-magic guess MIMETYPE, you need to install `libmagic`", stacklevel=2)
|
||||||
magic = None # type: ignore
|
magic = None # type: ignore[assignment]
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -211,8 +211,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
||||||
user=user,
|
user=user,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
)
|
)
|
||||||
# FIXME: Type hinting issue here, ignore it for now, will fix it later
|
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||||
return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) # type: ignore
|
|
||||||
|
|
||||||
def _generate_worker(
|
def _generate_worker(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
data = cls._error_to_stream_response(sub_stream_response.err)
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
response_chunk.update(data)
|
response_chunk.update(data)
|
||||||
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||||
response_chunk.update(sub_stream_response.to_ignore_detail_dict()) # ty: ignore [unresolved-attribute]
|
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||||
else:
|
else:
|
||||||
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
response_chunk.update(sub_stream_response.model_dump(mode="json"))
|
||||||
yield response_chunk
|
yield response_chunk
|
||||||
|
|
|
||||||
|
|
@ -98,7 +98,7 @@ class RateLimit:
|
||||||
else:
|
else:
|
||||||
return RateLimitGenerator(
|
return RateLimitGenerator(
|
||||||
rate_limit=self,
|
rate_limit=self,
|
||||||
generator=generator, # ty: ignore [invalid-argument-type]
|
generator=generator,
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ class BasedGenerateTaskPipeline:
|
||||||
if isinstance(e, InvokeAuthorizationError):
|
if isinstance(e, InvokeAuthorizationError):
|
||||||
err = InvokeAuthorizationError("Incorrect API key provided")
|
err = InvokeAuthorizationError("Incorrect API key provided")
|
||||||
elif isinstance(e, InvokeError | ValueError):
|
elif isinstance(e, InvokeError | ValueError):
|
||||||
err = e # ty: ignore [invalid-assignment]
|
err = e
|
||||||
else:
|
else:
|
||||||
description = getattr(e, "description", None)
|
description = getattr(e, "description", None)
|
||||||
err = Exception(description if description is not None else str(e))
|
err = Exception(description if description is not None else str(e))
|
||||||
|
|
|
||||||
|
|
@ -1868,7 +1868,7 @@ class ProviderConfigurations(BaseModel):
|
||||||
if "/" not in key:
|
if "/" not in key:
|
||||||
key = str(ModelProviderID(key))
|
key = str(ModelProviderID(key))
|
||||||
|
|
||||||
return self.configurations.get(key, default) # type: ignore
|
return self.configurations.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
class ProviderModelBundle(BaseModel):
|
class ProviderModelBundle(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,7 @@ def import_module_from_source(*, module_name: str, py_file_path: AnyStr, use_laz
|
||||||
else:
|
else:
|
||||||
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
# Refer to: https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
|
||||||
# FIXME: mypy does not support the type of spec.loader
|
# FIXME: mypy does not support the type of spec.loader
|
||||||
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore
|
spec = importlib.util.spec_from_file_location(module_name, py_file_path) # type: ignore[assignment]
|
||||||
if not spec or not spec.loader:
|
if not spec or not spec.loader:
|
||||||
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
|
raise Exception(f"Failed to load module {module_name} from {py_file_path!r}")
|
||||||
if use_lazy_loader:
|
if use_lazy_loader:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from langfuse import Langfuse # type: ignore
|
from langfuse import Langfuse
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.ops.base_trace_instance import BaseTraceInstance
|
from core.ops.base_trace_instance import BaseTraceInstance
|
||||||
|
|
|
||||||
|
|
@ -180,7 +180,7 @@ class BasePluginClient:
|
||||||
Make a request to the plugin daemon inner API and return the response as a model.
|
Make a request to the plugin daemon inner API and return the response as a model.
|
||||||
"""
|
"""
|
||||||
response = self._request(method, path, headers, data, params, files)
|
response = self._request(method, path, headers, data, params, files)
|
||||||
return type_(**response.json()) # type: ignore
|
return type_(**response.json()) # type: ignore[return-value]
|
||||||
|
|
||||||
def _request_with_plugin_daemon_response(
|
def _request_with_plugin_daemon_response(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ class CeleryWorkflowExecutionRepository(WorkflowExecutionRepository):
|
||||||
tenant_id = extract_tenant_id(user)
|
tenant_id = extract_tenant_id(user)
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||||
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
|
self._tenant_id = tenant_id
|
||||||
|
|
||||||
# Store app context
|
# Store app context
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
|
|
|
||||||
|
|
@ -81,7 +81,7 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
|
||||||
tenant_id = extract_tenant_id(user)
|
tenant_id = extract_tenant_id(user)
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
raise ValueError("User must have a tenant_id or current_tenant_id")
|
raise ValueError("User must have a tenant_id or current_tenant_id")
|
||||||
self._tenant_id = tenant_id # type: ignore[assignment] # We've already checked tenant_id is not None
|
self._tenant_id = tenant_id
|
||||||
|
|
||||||
# Store app context
|
# Store app context
|
||||||
self._app_id = app_id
|
self._app_id = app_id
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ class DifyCoreRepositoryFactory:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
repository_class = import_string(class_path)
|
repository_class = import_string(class_path)
|
||||||
return repository_class( # type: ignore[no-any-return]
|
return repository_class(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
user=user,
|
user=user,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
|
@ -96,7 +96,7 @@ class DifyCoreRepositoryFactory:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
repository_class = import_string(class_path)
|
repository_class = import_string(class_path)
|
||||||
return repository_class( # type: ignore[no-any-return]
|
return repository_class(
|
||||||
session_factory=session_factory,
|
session_factory=session_factory,
|
||||||
user=user,
|
user=user,
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||||
"""
|
"""
|
||||||
returns the tool that the provider can provide
|
returns the tool that the provider can provide
|
||||||
"""
|
"""
|
||||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None) # type: ignore
|
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def need_credentials(self) -> bool:
|
def need_credentials(self) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ class TTSTool(BuiltinTool):
|
||||||
content_text=tool_parameters.get("text"), # type: ignore
|
content_text=tool_parameters.get("text"), # type: ignore
|
||||||
user=user_id,
|
user=user_id,
|
||||||
tenant_id=self.runtime.tenant_id,
|
tenant_id=self.runtime.tenant_id,
|
||||||
voice=voice, # type: ignore
|
voice=voice,
|
||||||
)
|
)
|
||||||
buffer = io.BytesIO()
|
buffer = io.BytesIO()
|
||||||
for chunk in tts:
|
for chunk in tts:
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ class LocaltimeToTimestampTool(BuiltinTool):
|
||||||
|
|
||||||
yield self.create_text_message(f"{timestamp}")
|
yield self.create_text_message(f"{timestamp}")
|
||||||
|
|
||||||
|
# TODO: this method's type is messy
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
def localtime_to_timestamp(localtime: str, time_format: str, local_tz=None) -> int | None:
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,6 @@ class TimezoneConversionTool(BuiltinTool):
|
||||||
datetime_with_tz = input_timezone.localize(local_time)
|
datetime_with_tz = input_timezone.localize(local_time)
|
||||||
# timezone convert
|
# timezone convert
|
||||||
converted_datetime = datetime_with_tz.astimezone(output_timezone)
|
converted_datetime = datetime_with_tz.astimezone(output_timezone)
|
||||||
return converted_datetime.strftime(format=time_format) # type: ignore
|
return converted_datetime.strftime(time_format)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise ToolInvokeError(str(e))
|
raise ToolInvokeError(str(e))
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
|
def get_tool(self, tool_name: str) -> MCPTool:
|
||||||
"""
|
"""
|
||||||
return tool with given name
|
return tool with given name
|
||||||
"""
|
"""
|
||||||
|
|
@ -128,7 +128,7 @@ class MCPToolProviderController(ToolProviderController):
|
||||||
sse_read_timeout=self.sse_read_timeout,
|
sse_read_timeout=self.sse_read_timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tools(self) -> list[MCPTool]: # type: ignore
|
def get_tools(self) -> list[MCPTool]:
|
||||||
"""
|
"""
|
||||||
get all tools
|
get all tools
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class ToolLabelManager:
|
||||||
labels = cls.filter_tool_labels(labels)
|
labels = cls.filter_tool_labels(labels)
|
||||||
|
|
||||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||||
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
|
provider_id = controller.provider_id
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unsupported tool type")
|
raise ValueError("Unsupported tool type")
|
||||||
|
|
||||||
|
|
@ -51,7 +51,7 @@ class ToolLabelManager:
|
||||||
Get tool labels
|
Get tool labels
|
||||||
"""
|
"""
|
||||||
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
|
||||||
provider_id = controller.provider_id # ty: ignore [unresolved-attribute]
|
provider_id = controller.provider_id
|
||||||
elif isinstance(controller, BuiltinToolProviderController):
|
elif isinstance(controller, BuiltinToolProviderController):
|
||||||
return controller.tool_labels
|
return controller.tool_labels
|
||||||
else:
|
else:
|
||||||
|
|
@ -85,7 +85,7 @@ class ToolLabelManager:
|
||||||
provider_ids = []
|
provider_ids = []
|
||||||
for controller in tool_providers:
|
for controller in tool_providers:
|
||||||
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
assert isinstance(controller, ApiToolProviderController | WorkflowToolProviderController)
|
||||||
provider_ids.append(controller.provider_id) # ty: ignore [unresolved-attribute]
|
provider_ids.append(controller.provider_id)
|
||||||
|
|
||||||
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
labels = db.session.scalars(select(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids))).all()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -193,18 +193,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||||
DatasetDocument.enabled == True,
|
DatasetDocument.enabled == True,
|
||||||
DatasetDocument.archived == False,
|
DatasetDocument.archived == False,
|
||||||
)
|
)
|
||||||
document = db.session.scalar(dataset_document_stmt) # type: ignore
|
document = db.session.scalar(dataset_document_stmt)
|
||||||
if dataset and document:
|
if dataset and document:
|
||||||
source = RetrievalSourceMetadata(
|
source = RetrievalSourceMetadata(
|
||||||
dataset_id=dataset.id,
|
dataset_id=dataset.id,
|
||||||
dataset_name=dataset.name,
|
dataset_name=dataset.name,
|
||||||
document_id=document.id, # type: ignore
|
document_id=document.id,
|
||||||
document_name=document.name, # type: ignore
|
document_name=document.name,
|
||||||
data_source_type=document.data_source_type, # type: ignore
|
data_source_type=document.data_source_type,
|
||||||
segment_id=segment.id,
|
segment_id=segment.id,
|
||||||
retriever_from=self.retriever_from,
|
retriever_from=self.retriever_from,
|
||||||
score=record.score or 0.0,
|
score=record.score or 0.0,
|
||||||
doc_metadata=document.doc_metadata, # type: ignore
|
doc_metadata=document.doc_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.retriever_from == "dev":
|
if self.retriever_from == "dev":
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ from typing import Any, cast
|
||||||
from urllib.parse import unquote
|
from urllib.parse import unquote
|
||||||
|
|
||||||
import chardet
|
import chardet
|
||||||
import cloudscraper # type: ignore
|
import cloudscraper
|
||||||
from readabilipy import simple_json_from_html_string # type: ignore
|
from readabilipy import simple_json_from_html_string
|
||||||
|
|
||||||
from core.helper import ssrf_proxy
|
from core.helper import ssrf_proxy
|
||||||
from core.rag.extractor import extract_processor
|
from core.rag.extractor import extract_processor
|
||||||
|
|
@ -63,8 +63,8 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
||||||
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||||
elif response.status_code == 403:
|
elif response.status_code == 403:
|
||||||
scraper = cloudscraper.create_scraper()
|
scraper = cloudscraper.create_scraper()
|
||||||
scraper.perform_request = ssrf_proxy.make_request # type: ignore
|
scraper.perform_request = ssrf_proxy.make_request
|
||||||
response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore
|
response = scraper.get(url, headers=headers, timeout=(120, 300))
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
return f"URL returned status code {response.status_code}."
|
return f"URL returned status code {response.status_code}."
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml
|
||||||
from yaml import YAMLError
|
from yaml import YAMLError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
|
||||||
|
|
@ -99,7 +99,7 @@ class WorkflowToolProviderController(ToolProviderController):
|
||||||
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
|
||||||
|
|
||||||
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
|
||||||
return next(filter(lambda x: x.variable == variable_name, variables), None) # type: ignore
|
return next(filter(lambda x: x.variable == variable_name, variables), None)
|
||||||
|
|
||||||
user = db_provider.user
|
user = db_provider.user
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from .types import SegmentType
|
||||||
|
|
||||||
class SegmentGroup(Segment):
|
class SegmentGroup(Segment):
|
||||||
value_type: SegmentType = SegmentType.GROUP
|
value_type: SegmentType = SegmentType.GROUP
|
||||||
value: list[Segment] = None # type: ignore
|
value: list[Segment]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self):
|
def text(self):
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ class Segment(BaseModel):
|
||||||
model_config = ConfigDict(frozen=True)
|
model_config = ConfigDict(frozen=True)
|
||||||
|
|
||||||
value_type: SegmentType
|
value_type: SegmentType
|
||||||
value: Any = None
|
value: Any
|
||||||
|
|
||||||
@field_validator("value_type")
|
@field_validator("value_type")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -74,12 +74,12 @@ class NoneSegment(Segment):
|
||||||
|
|
||||||
class StringSegment(Segment):
|
class StringSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.STRING
|
value_type: SegmentType = SegmentType.STRING
|
||||||
value: str = None # type: ignore
|
value: str
|
||||||
|
|
||||||
|
|
||||||
class FloatSegment(Segment):
|
class FloatSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.FLOAT
|
value_type: SegmentType = SegmentType.FLOAT
|
||||||
value: float = None # type: ignore
|
value: float
|
||||||
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
|
||||||
# The following tests cannot pass.
|
# The following tests cannot pass.
|
||||||
#
|
#
|
||||||
|
|
@ -98,12 +98,12 @@ class FloatSegment(Segment):
|
||||||
|
|
||||||
class IntegerSegment(Segment):
|
class IntegerSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.INTEGER
|
value_type: SegmentType = SegmentType.INTEGER
|
||||||
value: int = None # type: ignore
|
value: int
|
||||||
|
|
||||||
|
|
||||||
class ObjectSegment(Segment):
|
class ObjectSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.OBJECT
|
value_type: SegmentType = SegmentType.OBJECT
|
||||||
value: Mapping[str, Any] = None # type: ignore
|
value: Mapping[str, Any]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
|
|
@ -136,7 +136,7 @@ class ArraySegment(Segment):
|
||||||
|
|
||||||
class FileSegment(Segment):
|
class FileSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.FILE
|
value_type: SegmentType = SegmentType.FILE
|
||||||
value: File = None # type: ignore
|
value: File
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
|
|
@ -153,17 +153,17 @@ class FileSegment(Segment):
|
||||||
|
|
||||||
class BooleanSegment(Segment):
|
class BooleanSegment(Segment):
|
||||||
value_type: SegmentType = SegmentType.BOOLEAN
|
value_type: SegmentType = SegmentType.BOOLEAN
|
||||||
value: bool = None # type: ignore
|
value: bool
|
||||||
|
|
||||||
|
|
||||||
class ArrayAnySegment(ArraySegment):
|
class ArrayAnySegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_ANY
|
value_type: SegmentType = SegmentType.ARRAY_ANY
|
||||||
value: Sequence[Any] = None # type: ignore
|
value: Sequence[Any]
|
||||||
|
|
||||||
|
|
||||||
class ArrayStringSegment(ArraySegment):
|
class ArrayStringSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_STRING
|
value_type: SegmentType = SegmentType.ARRAY_STRING
|
||||||
value: Sequence[str] = None # type: ignore
|
value: Sequence[str]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def text(self) -> str:
|
def text(self) -> str:
|
||||||
|
|
@ -175,17 +175,17 @@ class ArrayStringSegment(ArraySegment):
|
||||||
|
|
||||||
class ArrayNumberSegment(ArraySegment):
|
class ArrayNumberSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
value_type: SegmentType = SegmentType.ARRAY_NUMBER
|
||||||
value: Sequence[float | int] = None # type: ignore
|
value: Sequence[float | int]
|
||||||
|
|
||||||
|
|
||||||
class ArrayObjectSegment(ArraySegment):
|
class ArrayObjectSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
value_type: SegmentType = SegmentType.ARRAY_OBJECT
|
||||||
value: Sequence[Mapping[str, Any]] = None # type: ignore
|
value: Sequence[Mapping[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
class ArrayFileSegment(ArraySegment):
|
class ArrayFileSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_FILE
|
value_type: SegmentType = SegmentType.ARRAY_FILE
|
||||||
value: Sequence[File] = None # type: ignore
|
value: Sequence[File]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def markdown(self) -> str:
|
def markdown(self) -> str:
|
||||||
|
|
@ -205,7 +205,7 @@ class ArrayFileSegment(ArraySegment):
|
||||||
|
|
||||||
class ArrayBooleanSegment(ArraySegment):
|
class ArrayBooleanSegment(ArraySegment):
|
||||||
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
value_type: SegmentType = SegmentType.ARRAY_BOOLEAN
|
||||||
value: Sequence[bool] = None # type: ignore
|
value: Sequence[bool]
|
||||||
|
|
||||||
|
|
||||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import json
|
import json
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from builtins import type as type_
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
|
|
@ -58,10 +59,9 @@ class DefaultValue(BaseModel):
|
||||||
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
raise DefaultValueTypeError(f"Invalid JSON format for value: {value}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _validate_array(value: Any, element_type: DefaultValueType) -> bool:
|
def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool:
|
||||||
"""Unified array type validation"""
|
"""Unified array type validation"""
|
||||||
# FIXME, type ignore here for do not find the reason mypy complain, if find the root cause, please fix it
|
return isinstance(value, list) and all(isinstance(x, element_type) for x in value)
|
||||||
return isinstance(value, list) and all(isinstance(x, element_type) for x in value) # type: ignore
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_number(value: str) -> float:
|
def _convert_number(value: str) -> float:
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ from typing import Any
|
||||||
import chardet
|
import chardet
|
||||||
import docx
|
import docx
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pypandoc # type: ignore
|
import pypandoc
|
||||||
import pypdfium2 # type: ignore
|
import pypdfium2
|
||||||
import webvtt # type: ignore
|
import webvtt
|
||||||
import yaml # type: ignore
|
import yaml
|
||||||
from docx.document import Document
|
from docx.document import Document
|
||||||
from docx.oxml.table import CT_Tbl
|
from docx.oxml.table import CT_Tbl
|
||||||
from docx.oxml.text.paragraph import CT_P
|
from docx.oxml.text.paragraph import CT_P
|
||||||
|
|
|
||||||
|
|
@ -141,7 +141,7 @@ class KnowledgeRetrievalNode(Node):
|
||||||
def version(cls):
|
def version(cls):
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
def _run(self) -> NodeRunResult: # type: ignore
|
def _run(self) -> NodeRunResult:
|
||||||
# extract variables
|
# extract variables
|
||||||
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
|
||||||
if not isinstance(variable, StringSegment):
|
if not isinstance(variable, StringSegment):
|
||||||
|
|
@ -443,7 +443,7 @@ class KnowledgeRetrievalNode(Node):
|
||||||
metadata_condition = MetadataCondition(
|
metadata_condition = MetadataCondition(
|
||||||
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
logical_operator=node_data.metadata_filtering_conditions.logical_operator
|
||||||
if node_data.metadata_filtering_conditions
|
if node_data.metadata_filtering_conditions
|
||||||
else "or", # type: ignore
|
else "or",
|
||||||
conditions=conditions,
|
conditions=conditions,
|
||||||
)
|
)
|
||||||
elif node_data.metadata_filtering_mode == "manual":
|
elif node_data.metadata_filtering_mode == "manual":
|
||||||
|
|
@ -457,10 +457,10 @@ class KnowledgeRetrievalNode(Node):
|
||||||
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
expected_value = self.graph_runtime_state.variable_pool.convert_template(
|
||||||
expected_value
|
expected_value
|
||||||
).value[0]
|
).value[0]
|
||||||
if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
|
if expected_value.value_type in {"number", "integer", "float"}:
|
||||||
expected_value = expected_value.value # type: ignore
|
expected_value = expected_value.value
|
||||||
elif expected_value.value_type == "string": # type: ignore
|
elif expected_value.value_type == "string":
|
||||||
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
|
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid expected metadata value type")
|
raise ValueError("Invalid expected metadata value type")
|
||||||
conditions.append(
|
conditions.append(
|
||||||
|
|
@ -487,7 +487,7 @@ class KnowledgeRetrievalNode(Node):
|
||||||
if (
|
if (
|
||||||
node_data.metadata_filtering_conditions
|
node_data.metadata_filtering_conditions
|
||||||
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
and node_data.metadata_filtering_conditions.logical_operator == "and"
|
||||||
): # type: ignore
|
):
|
||||||
document_query = document_query.where(and_(*filters))
|
document_query = document_query.where(and_(*filters))
|
||||||
else:
|
else:
|
||||||
document_query = document_query.where(or_(*filters))
|
document_query = document_query.where(or_(*filters))
|
||||||
|
|
|
||||||
|
|
@ -260,7 +260,7 @@ class VariablePool(BaseModel):
|
||||||
# This ensures that we can keep the id of the system variables intact.
|
# This ensures that we can keep the id of the system variables intact.
|
||||||
if self._has(selector):
|
if self._has(selector):
|
||||||
continue
|
continue
|
||||||
self.add(selector, value) # type: ignore
|
self.add(selector, value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def empty(cls) -> "VariablePool":
|
def empty(cls) -> "VariablePool":
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ def is_enabled() -> bool:
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
from flask_compress import Compress # type: ignore
|
from flask_compress import Compress
|
||||||
|
|
||||||
compress = Compress()
|
compress = Compress()
|
||||||
compress.init_app(app)
|
compress.init_app(app)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import json
|
import json
|
||||||
|
|
||||||
import flask_login # type: ignore
|
import flask_login
|
||||||
from flask import Response, request
|
from flask import Response, request
|
||||||
from flask_login import user_loaded_from_request, user_logged_in
|
from flask_login import user_loaded_from_request, user_logged_in
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import NotFound, Unauthorized
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ from dify_app import DifyApp
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
import flask_migrate # type: ignore
|
import flask_migrate
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -103,7 +103,7 @@ def init_app(app: DifyApp):
|
||||||
def shutdown_tracer():
|
def shutdown_tracer():
|
||||||
provider = trace.get_tracer_provider()
|
provider = trace.get_tracer_provider()
|
||||||
if hasattr(provider, "force_flush"):
|
if hasattr(provider, "force_flush"):
|
||||||
provider.force_flush() # ty: ignore [call-non-callable]
|
provider.force_flush()
|
||||||
|
|
||||||
class ExceptionLoggingHandler(logging.Handler):
|
class ExceptionLoggingHandler(logging.Handler):
|
||||||
"""Custom logging handler that creates spans for logging.exception() calls"""
|
"""Custom logging handler that creates spans for logging.exception() calls"""
|
||||||
|
|
|
||||||
|
|
@ -6,4 +6,4 @@ def init_app(app: DifyApp):
|
||||||
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
|
if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED:
|
||||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||||
|
|
||||||
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore
|
app.wsgi_app = ProxyFix(app.wsgi_app, x_port=1) # type: ignore[method-assign]
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from dify_app import DifyApp
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
if dify_config.SENTRY_DSN:
|
if dify_config.SENTRY_DSN:
|
||||||
import sentry_sdk
|
import sentry_sdk
|
||||||
from langfuse import parse_error # type: ignore
|
from langfuse import parse_error
|
||||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||||
from sentry_sdk.integrations.flask import FlaskIntegration
|
from sentry_sdk.integrations.flask import FlaskIntegration
|
||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import posixpath
|
import posixpath
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
import oss2 as aliyun_s3 # type: ignore
|
import oss2 as aliyun_s3
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from baidubce.auth.bce_credentials import BceCredentials # type: ignore
|
from baidubce.auth.bce_credentials import BceCredentials
|
||||||
from baidubce.bce_client_configuration import BceClientConfiguration # type: ignore
|
from baidubce.bce_client_configuration import BceClientConfiguration
|
||||||
from baidubce.services.bos.bos_client import BosClient # type: ignore
|
from baidubce.services.bos.bos_client import BosClient
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from collections.abc import Generator
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import clickzetta # type: ignore[import]
|
import clickzetta
|
||||||
from pydantic import BaseModel, model_validator
|
from pydantic import BaseModel, model_validator
|
||||||
|
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class VolumePermissionManager:
|
||||||
# Support two initialization methods: connection object or configuration dictionary
|
# Support two initialization methods: connection object or configuration dictionary
|
||||||
if isinstance(connection_or_config, dict):
|
if isinstance(connection_or_config, dict):
|
||||||
# Create connection from configuration dictionary
|
# Create connection from configuration dictionary
|
||||||
import clickzetta # type: ignore[import-untyped]
|
import clickzetta
|
||||||
|
|
||||||
config = connection_or_config
|
config = connection_or_config
|
||||||
self._connection = clickzetta.connect(
|
self._connection = clickzetta.connect(
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import io
|
||||||
import json
|
import json
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from google.cloud import storage as google_cloud_storage # type: ignore
|
from google.cloud import storage as google_cloud_storage
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from obs import ObsClient # type: ignore
|
from obs import ObsClient
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
import boto3 # type: ignore
|
import boto3
|
||||||
from botocore.exceptions import ClientError # type: ignore
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
from qcloud_cos import CosConfig, CosS3Client # type: ignore
|
from qcloud_cos import CosConfig, CosS3Client
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
|
|
||||||
import tos # type: ignore
|
import tos
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from extensions.storage.base_storage import BaseStorage
|
from extensions.storage.base_storage import BaseStorage
|
||||||
|
|
|
||||||
|
|
@ -146,6 +146,6 @@ class ExternalApi(Api):
|
||||||
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
|
kwargs["doc"] = dify_config.SWAGGER_UI_PATH if dify_config.SWAGGER_UI_ENABLED else False
|
||||||
|
|
||||||
# manual separate call on construction and init_app to ensure configs in kwargs effective
|
# manual separate call on construction and init_app to ensure configs in kwargs effective
|
||||||
super().__init__(app=None, *args, **kwargs) # type: ignore
|
super().__init__(app=None, *args, **kwargs)
|
||||||
self.init_app(app, **kwargs)
|
self.init_app(app, **kwargs)
|
||||||
register_external_error_handlers(self)
|
register_external_error_handlers(self)
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ from hashlib import sha1
|
||||||
|
|
||||||
import Crypto.Hash.SHA1
|
import Crypto.Hash.SHA1
|
||||||
import Crypto.Util.number
|
import Crypto.Util.number
|
||||||
import gmpy2 # type: ignore
|
import gmpy2
|
||||||
from Crypto import Random
|
from Crypto import Random
|
||||||
from Crypto.Signature.pss import MGF1
|
from Crypto.Signature.pss import MGF1
|
||||||
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
|
from Crypto.Util.number import bytes_to_long, ceil_div, long_to_bytes
|
||||||
|
|
@ -136,7 +136,7 @@ class PKCS1OAepCipher:
|
||||||
# Step 3a (OS2IP)
|
# Step 3a (OS2IP)
|
||||||
em_int = bytes_to_long(em)
|
em_int = bytes_to_long(em)
|
||||||
# Step 3b (RSAEP)
|
# Step 3b (RSAEP)
|
||||||
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
|
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
|
||||||
# Step 3c (I2OSP)
|
# Step 3c (I2OSP)
|
||||||
c = long_to_bytes(m_int, k)
|
c = long_to_bytes(m_int, k)
|
||||||
return c
|
return c
|
||||||
|
|
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
|
||||||
ct_int = bytes_to_long(ciphertext)
|
ct_int = bytes_to_long(ciphertext)
|
||||||
# Step 2b (RSADP)
|
# Step 2b (RSADP)
|
||||||
# m_int = self._key._decrypt(ct_int)
|
# m_int = self._key._decrypt(ct_int)
|
||||||
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
|
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
|
||||||
# Complete step 2c (I2OSP)
|
# Complete step 2c (I2OSP)
|
||||||
em = long_to_bytes(m_int, k)
|
em = long_to_bytes(m_int, k)
|
||||||
# Step 3a
|
# Step 3a
|
||||||
|
|
@ -191,12 +191,12 @@ class PKCS1OAepCipher:
|
||||||
# Step 3g
|
# Step 3g
|
||||||
one_pos = hLen + db[hLen:].find(b"\x01")
|
one_pos = hLen + db[hLen:].find(b"\x01")
|
||||||
lHash1 = db[:hLen]
|
lHash1 = db[:hLen]
|
||||||
invalid = bord(y) | int(one_pos < hLen) # type: ignore
|
invalid = bord(y) | int(one_pos < hLen) # type: ignore[arg-type]
|
||||||
hash_compare = strxor(lHash1, lHash)
|
hash_compare = strxor(lHash1, lHash)
|
||||||
for x in hash_compare:
|
for x in hash_compare:
|
||||||
invalid |= bord(x) # type: ignore
|
invalid |= bord(x) # type: ignore[arg-type]
|
||||||
for x in db[hLen:one_pos]:
|
for x in db[hLen:one_pos]:
|
||||||
invalid |= bord(x) # type: ignore
|
invalid |= bord(x) # type: ignore[arg-type]
|
||||||
if invalid != 0:
|
if invalid != 0:
|
||||||
raise ValueError("Incorrect decryption.")
|
raise ValueError("Incorrect decryption.")
|
||||||
# Step 4
|
# Step 4
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from functools import wraps
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from flask import current_app, g, has_request_context, request
|
from flask import current_app, g, has_request_context, request
|
||||||
from flask_login.config import EXEMPT_METHODS # type: ignore
|
from flask_login.config import EXEMPT_METHODS
|
||||||
from werkzeug.local import LocalProxy
|
from werkzeug.local import LocalProxy
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
@ -87,7 +87,7 @@ def _get_user() -> EndUser | Account | None:
|
||||||
if "_login_user" not in g:
|
if "_login_user" not in g:
|
||||||
current_app.login_manager._load_user() # type: ignore
|
current_app.login_manager._load_user() # type: ignore
|
||||||
|
|
||||||
return g._login_user # type: ignore
|
return g._login_user
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import sendgrid # type: ignore
|
import sendgrid
|
||||||
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
|
from python_http_client.exceptions import ForbiddenError, UnauthorizedError
|
||||||
from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
|
from sendgrid.helpers.mail import Content, Email, Mail, To
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import DateTime, String, func, select
|
from sqlalchemy import DateTime, String, func, select
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import deprecated
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask_login import UserMixin # type: ignore[import-untyped]
|
from flask_login import UserMixin
|
||||||
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
from sqlalchemy import Float, Index, PrimaryKeyConstraint, String, exists, func, select, text
|
||||||
from sqlalchemy.orm import Mapped, Session, mapped_column
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,25 @@
|
||||||
"opentelemetry.instrumentation.requests",
|
"opentelemetry.instrumentation.requests",
|
||||||
"opentelemetry.instrumentation.sqlalchemy",
|
"opentelemetry.instrumentation.sqlalchemy",
|
||||||
"opentelemetry.instrumentation.redis",
|
"opentelemetry.instrumentation.redis",
|
||||||
"opentelemetry.instrumentation.httpx"
|
"langfuse",
|
||||||
|
"cloudscraper",
|
||||||
|
"readabilipy",
|
||||||
|
"pypandoc",
|
||||||
|
"pypdfium2",
|
||||||
|
"webvtt",
|
||||||
|
"flask_compress",
|
||||||
|
"oss2",
|
||||||
|
"baidubce.auth.bce_credentials",
|
||||||
|
"baidubce.bce_client_configuration",
|
||||||
|
"baidubce.services.bos.bos_client",
|
||||||
|
"clickzetta",
|
||||||
|
"google.cloud",
|
||||||
|
"obs",
|
||||||
|
"qcloud_cos",
|
||||||
|
"tos",
|
||||||
|
"gmpy2",
|
||||||
|
"sendgrid",
|
||||||
|
"sendgrid.helpers.mail"
|
||||||
],
|
],
|
||||||
"reportUnknownMemberType": "hint",
|
"reportUnknownMemberType": "hint",
|
||||||
"reportUnknownParameterType": "hint",
|
"reportUnknownParameterType": "hint",
|
||||||
|
|
@ -28,7 +46,7 @@
|
||||||
"reportUnnecessaryComparison": "hint",
|
"reportUnnecessaryComparison": "hint",
|
||||||
"reportUnnecessaryIsInstance": "hint",
|
"reportUnnecessaryIsInstance": "hint",
|
||||||
"reportUntypedFunctionDecorator": "hint",
|
"reportUntypedFunctionDecorator": "hint",
|
||||||
|
"reportUnnecessaryTypeIgnoreComment": "hint",
|
||||||
"reportAttributeAccessIssue": "hint",
|
"reportAttributeAccessIssue": "hint",
|
||||||
"pythonVersion": "3.11",
|
"pythonVersion": "3.11",
|
||||||
"pythonPlatform": "All"
|
"pythonPlatform": "All"
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
repository_class = import_string(class_path)
|
repository_class = import_string(class_path)
|
||||||
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
|
return repository_class(session_maker=session_maker)
|
||||||
except (ImportError, Exception) as e:
|
except (ImportError, Exception) as e:
|
||||||
raise RepositoryImportError(
|
raise RepositoryImportError(
|
||||||
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
|
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
|
||||||
|
|
@ -77,6 +77,6 @@ class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
repository_class = import_string(class_path)
|
repository_class = import_string(class_path)
|
||||||
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
|
return repository_class(session_maker=session_maker)
|
||||||
except (ImportError, Exception) as e:
|
except (ImportError, Exception) as e:
|
||||||
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
|
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from enum import StrEnum
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import yaml # type: ignore
|
import yaml
|
||||||
from Crypto.Cipher import AES
|
from Crypto.Cipher import AES
|
||||||
from Crypto.Util.Padding import pad, unpad
|
from Crypto.Util.Padding import pad, unpad
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
@ -563,7 +563,7 @@ class AppDslService:
|
||||||
else:
|
else:
|
||||||
cls._append_model_config_export_data(export_data, app_model)
|
cls._append_model_config_export_data(export_data, app_model)
|
||||||
|
|
||||||
return yaml.dump(export_data, allow_unicode=True) # type: ignore
|
return yaml.dump(export_data, allow_unicode=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _append_workflow_export_data(
|
def _append_workflow_export_data(
|
||||||
|
|
|
||||||
|
|
@ -241,9 +241,9 @@ class DatasetService:
|
||||||
dataset.created_by = account.id
|
dataset.created_by = account.id
|
||||||
dataset.updated_by = account.id
|
dataset.updated_by = account.id
|
||||||
dataset.tenant_id = tenant_id
|
dataset.tenant_id = tenant_id
|
||||||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore
|
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
||||||
dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore
|
dataset.embedding_model = embedding_model.model if embedding_model else None
|
||||||
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore
|
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
|
||||||
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
||||||
dataset.provider = provider
|
dataset.provider = provider
|
||||||
db.session.add(dataset)
|
db.session.add(dataset)
|
||||||
|
|
@ -1416,6 +1416,8 @@ class DocumentService:
|
||||||
# check document limit
|
# check document limit
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
|
assert knowledge_config.data_source
|
||||||
|
assert knowledge_config.data_source.info_list.file_info_list
|
||||||
|
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
|
|
@ -1424,15 +1426,16 @@ class DocumentService:
|
||||||
count = 0
|
count = 0
|
||||||
if knowledge_config.data_source:
|
if knowledge_config.data_source:
|
||||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||||
count = len(upload_file_list)
|
count = len(upload_file_list)
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
|
notion_info_list = knowledge_config.data_source.info_list.notion_info_list or []
|
||||||
for notion_info in notion_info_list: # type: ignore
|
for notion_info in notion_info_list:
|
||||||
count = count + len(notion_info.pages)
|
count = count + len(notion_info.pages)
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||||
website_info = knowledge_config.data_source.info_list.website_info_list
|
website_info = knowledge_config.data_source.info_list.website_info_list
|
||||||
count = len(website_info.urls) # type: ignore
|
assert website_info
|
||||||
|
count = len(website_info.urls)
|
||||||
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
|
||||||
|
|
||||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||||
|
|
@ -1444,7 +1447,7 @@ class DocumentService:
|
||||||
|
|
||||||
# if dataset is empty, update dataset data_source_type
|
# if dataset is empty, update dataset data_source_type
|
||||||
if not dataset.data_source_type:
|
if not dataset.data_source_type:
|
||||||
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type # type: ignore
|
dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
|
||||||
|
|
||||||
if not dataset.indexing_technique:
|
if not dataset.indexing_technique:
|
||||||
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
|
||||||
|
|
@ -1481,7 +1484,7 @@ class DocumentService:
|
||||||
knowledge_config.retrieval_model.model_dump()
|
knowledge_config.retrieval_model.model_dump()
|
||||||
if knowledge_config.retrieval_model
|
if knowledge_config.retrieval_model
|
||||||
else default_retrieval_model
|
else default_retrieval_model
|
||||||
) # type: ignore
|
)
|
||||||
|
|
||||||
documents = []
|
documents = []
|
||||||
if knowledge_config.original_document_id:
|
if knowledge_config.original_document_id:
|
||||||
|
|
@ -1523,11 +1526,12 @@ class DocumentService:
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
|
lock_name = f"add_document_lock_dataset_id_{dataset.id}"
|
||||||
with redis_client.lock(lock_name, timeout=600):
|
with redis_client.lock(lock_name, timeout=600):
|
||||||
|
assert dataset_process_rule
|
||||||
position = DocumentService.get_documents_position(dataset.id)
|
position = DocumentService.get_documents_position(dataset.id)
|
||||||
document_ids = []
|
document_ids = []
|
||||||
duplicate_document_ids = []
|
duplicate_document_ids = []
|
||||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
|
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||||
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||||
for file_id in upload_file_list:
|
for file_id in upload_file_list:
|
||||||
file = (
|
file = (
|
||||||
db.session.query(UploadFile)
|
db.session.query(UploadFile)
|
||||||
|
|
@ -1540,7 +1544,7 @@ class DocumentService:
|
||||||
raise FileNotExistsError()
|
raise FileNotExistsError()
|
||||||
|
|
||||||
file_name = file.name
|
file_name = file.name
|
||||||
data_source_info = {
|
data_source_info: dict[str, str | bool] = {
|
||||||
"upload_file_id": file_id,
|
"upload_file_id": file_id,
|
||||||
}
|
}
|
||||||
# check duplicate
|
# check duplicate
|
||||||
|
|
@ -1557,7 +1561,7 @@ class DocumentService:
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
if document:
|
if document:
|
||||||
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
|
document.dataset_process_rule_id = dataset_process_rule.id
|
||||||
document.updated_at = naive_utc_now()
|
document.updated_at = naive_utc_now()
|
||||||
document.created_from = created_from
|
document.created_from = created_from
|
||||||
document.doc_form = knowledge_config.doc_form
|
document.doc_form = knowledge_config.doc_form
|
||||||
|
|
@ -1571,8 +1575,8 @@ class DocumentService:
|
||||||
continue
|
continue
|
||||||
document = DocumentService.build_document(
|
document = DocumentService.build_document(
|
||||||
dataset,
|
dataset,
|
||||||
dataset_process_rule.id, # type: ignore
|
dataset_process_rule.id,
|
||||||
knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
knowledge_config.data_source.info_list.data_source_type,
|
||||||
knowledge_config.doc_form,
|
knowledge_config.doc_form,
|
||||||
knowledge_config.doc_language,
|
knowledge_config.doc_language,
|
||||||
data_source_info,
|
data_source_info,
|
||||||
|
|
@ -1587,7 +1591,7 @@ class DocumentService:
|
||||||
document_ids.append(document.id)
|
document_ids.append(document.id)
|
||||||
documents.append(document)
|
documents.append(document)
|
||||||
position += 1
|
position += 1
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
|
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
|
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
|
||||||
if not notion_info_list:
|
if not notion_info_list:
|
||||||
raise ValueError("No notion info list found.")
|
raise ValueError("No notion info list found.")
|
||||||
|
|
@ -1616,15 +1620,15 @@ class DocumentService:
|
||||||
"credential_id": notion_info.credential_id,
|
"credential_id": notion_info.credential_id,
|
||||||
"notion_workspace_id": workspace_id,
|
"notion_workspace_id": workspace_id,
|
||||||
"notion_page_id": page.page_id,
|
"notion_page_id": page.page_id,
|
||||||
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,
|
"notion_page_icon": page.page_icon.model_dump() if page.page_icon else None, # type: ignore
|
||||||
"type": page.type,
|
"type": page.type,
|
||||||
}
|
}
|
||||||
# Truncate page name to 255 characters to prevent DB field length errors
|
# Truncate page name to 255 characters to prevent DB field length errors
|
||||||
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
|
truncated_page_name = page.page_name[:255] if page.page_name else "nopagename"
|
||||||
document = DocumentService.build_document(
|
document = DocumentService.build_document(
|
||||||
dataset,
|
dataset,
|
||||||
dataset_process_rule.id, # type: ignore
|
dataset_process_rule.id,
|
||||||
knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
knowledge_config.data_source.info_list.data_source_type,
|
||||||
knowledge_config.doc_form,
|
knowledge_config.doc_form,
|
||||||
knowledge_config.doc_language,
|
knowledge_config.doc_language,
|
||||||
data_source_info,
|
data_source_info,
|
||||||
|
|
@ -1644,8 +1648,8 @@ class DocumentService:
|
||||||
# delete not selected documents
|
# delete not selected documents
|
||||||
if len(exist_document) > 0:
|
if len(exist_document) > 0:
|
||||||
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
|
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||||
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
|
website_info = knowledge_config.data_source.info_list.website_info_list
|
||||||
if not website_info:
|
if not website_info:
|
||||||
raise ValueError("No website info list found.")
|
raise ValueError("No website info list found.")
|
||||||
urls = website_info.urls
|
urls = website_info.urls
|
||||||
|
|
@ -1663,8 +1667,8 @@ class DocumentService:
|
||||||
document_name = url
|
document_name = url
|
||||||
document = DocumentService.build_document(
|
document = DocumentService.build_document(
|
||||||
dataset,
|
dataset,
|
||||||
dataset_process_rule.id, # type: ignore
|
dataset_process_rule.id,
|
||||||
knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
knowledge_config.data_source.info_list.data_source_type,
|
||||||
knowledge_config.doc_form,
|
knowledge_config.doc_form,
|
||||||
knowledge_config.doc_language,
|
knowledge_config.doc_language,
|
||||||
data_source_info,
|
data_source_info,
|
||||||
|
|
@ -2071,7 +2075,7 @@ class DocumentService:
|
||||||
# update document data source
|
# update document data source
|
||||||
if document_data.data_source:
|
if document_data.data_source:
|
||||||
file_name = ""
|
file_name = ""
|
||||||
data_source_info = {}
|
data_source_info: dict[str, str | bool] = {}
|
||||||
if document_data.data_source.info_list.data_source_type == "upload_file":
|
if document_data.data_source.info_list.data_source_type == "upload_file":
|
||||||
if not document_data.data_source.info_list.file_info_list:
|
if not document_data.data_source.info_list.file_info_list:
|
||||||
raise ValueError("No file info list found.")
|
raise ValueError("No file info list found.")
|
||||||
|
|
@ -2128,7 +2132,7 @@ class DocumentService:
|
||||||
"url": url,
|
"url": url,
|
||||||
"provider": website_info.provider,
|
"provider": website_info.provider,
|
||||||
"job_id": website_info.job_id,
|
"job_id": website_info.job_id,
|
||||||
"only_main_content": website_info.only_main_content, # type: ignore
|
"only_main_content": website_info.only_main_content,
|
||||||
"mode": "crawl",
|
"mode": "crawl",
|
||||||
}
|
}
|
||||||
document.data_source_type = document_data.data_source.info_list.data_source_type
|
document.data_source_type = document_data.data_source.info_list.data_source_type
|
||||||
|
|
@ -2154,7 +2158,7 @@ class DocumentService:
|
||||||
|
|
||||||
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
|
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
|
||||||
{DocumentSegment.status: "re_segment"}
|
{DocumentSegment.status: "re_segment"}
|
||||||
) # type: ignore
|
)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
# trigger async task
|
# trigger async task
|
||||||
document_indexing_update_task.delay(document.dataset_id, document.id)
|
document_indexing_update_task.delay(document.dataset_id, document.id)
|
||||||
|
|
@ -2164,25 +2168,26 @@ class DocumentService:
|
||||||
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
|
def save_document_without_dataset_id(tenant_id: str, knowledge_config: KnowledgeConfig, account: Account):
|
||||||
assert isinstance(current_user, Account)
|
assert isinstance(current_user, Account)
|
||||||
assert current_user.current_tenant_id is not None
|
assert current_user.current_tenant_id is not None
|
||||||
|
assert knowledge_config.data_source
|
||||||
|
|
||||||
features = FeatureService.get_features(current_user.current_tenant_id)
|
features = FeatureService.get_features(current_user.current_tenant_id)
|
||||||
|
|
||||||
if features.billing.enabled:
|
if features.billing.enabled:
|
||||||
count = 0
|
count = 0
|
||||||
if knowledge_config.data_source.info_list.data_source_type == "upload_file": # type: ignore
|
if knowledge_config.data_source.info_list.data_source_type == "upload_file":
|
||||||
upload_file_list = (
|
upload_file_list = (
|
||||||
knowledge_config.data_source.info_list.file_info_list.file_ids # type: ignore
|
knowledge_config.data_source.info_list.file_info_list.file_ids
|
||||||
if knowledge_config.data_source.info_list.file_info_list # type: ignore
|
if knowledge_config.data_source.info_list.file_info_list
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
count = len(upload_file_list)
|
count = len(upload_file_list)
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "notion_import": # type: ignore
|
elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
|
||||||
notion_info_list = knowledge_config.data_source.info_list.notion_info_list # type: ignore
|
notion_info_list = knowledge_config.data_source.info_list.notion_info_list
|
||||||
if notion_info_list:
|
if notion_info_list:
|
||||||
for notion_info in notion_info_list:
|
for notion_info in notion_info_list:
|
||||||
count = count + len(notion_info.pages)
|
count = count + len(notion_info.pages)
|
||||||
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl": # type: ignore
|
elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
|
||||||
website_info = knowledge_config.data_source.info_list.website_info_list # type: ignore
|
website_info = knowledge_config.data_source.info_list.website_info_list
|
||||||
if website_info:
|
if website_info:
|
||||||
count = len(website_info.urls)
|
count = len(website_info.urls)
|
||||||
if features.billing.subscription.plan == "sandbox" and count > 1:
|
if features.billing.subscription.plan == "sandbox" and count > 1:
|
||||||
|
|
@ -2196,9 +2201,11 @@ class DocumentService:
|
||||||
dataset_collection_binding_id = None
|
dataset_collection_binding_id = None
|
||||||
retrieval_model = None
|
retrieval_model = None
|
||||||
if knowledge_config.indexing_technique == "high_quality":
|
if knowledge_config.indexing_technique == "high_quality":
|
||||||
|
assert knowledge_config.embedding_model_provider
|
||||||
|
assert knowledge_config.embedding_model
|
||||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||||
knowledge_config.embedding_model_provider, # type: ignore
|
knowledge_config.embedding_model_provider,
|
||||||
knowledge_config.embedding_model, # type: ignore
|
knowledge_config.embedding_model,
|
||||||
)
|
)
|
||||||
dataset_collection_binding_id = dataset_collection_binding.id
|
dataset_collection_binding_id = dataset_collection_binding.id
|
||||||
if knowledge_config.retrieval_model:
|
if knowledge_config.retrieval_model:
|
||||||
|
|
@ -2215,7 +2222,7 @@ class DocumentService:
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
name="",
|
name="",
|
||||||
data_source_type=knowledge_config.data_source.info_list.data_source_type, # type: ignore
|
data_source_type=knowledge_config.data_source.info_list.data_source_type,
|
||||||
indexing_technique=knowledge_config.indexing_technique,
|
indexing_technique=knowledge_config.indexing_technique,
|
||||||
created_by=account.id,
|
created_by=account.id,
|
||||||
embedding_model=knowledge_config.embedding_model,
|
embedding_model=knowledge_config.embedding_model,
|
||||||
|
|
@ -2224,7 +2231,7 @@ class DocumentService:
|
||||||
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
db.session.add(dataset) # type: ignore
|
db.session.add(dataset)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
|
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account)
|
||||||
|
|
|
||||||
|
|
@ -88,7 +88,7 @@ class HitTestingService:
|
||||||
db.session.add(dataset_query)
|
db.session.add(dataset_query)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return cls.compact_retrieve_response(query, all_documents) # type: ignore
|
return cls.compact_retrieve_response(query, all_documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def external_retrieve(
|
def external_retrieve(
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import boto3 # type: ignore
|
import boto3
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ class MetadataService:
|
||||||
document.doc_metadata = doc_metadata
|
document.doc_metadata = doc_metadata
|
||||||
db.session.add(document)
|
db.session.add(document)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
return metadata # type: ignore
|
return metadata
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Update metadata name failed")
|
logger.exception("Update metadata name failed")
|
||||||
finally:
|
finally:
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,7 @@ class ModelProviderService:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||||
return provider_configuration.get_provider_credential(credential_id=credential_id) # type: ignore
|
return provider_configuration.get_provider_credential(credential_id=credential_id)
|
||||||
|
|
||||||
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
def validate_provider_credentials(self, tenant_id: str, provider: str, credentials: dict):
|
||||||
"""
|
"""
|
||||||
|
|
@ -225,7 +225,7 @@ class ModelProviderService:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
provider_configuration = self._get_provider_configuration(tenant_id, provider)
|
||||||
return provider_configuration.get_custom_model_credential( # type: ignore
|
return provider_configuration.get_custom_model_credential(
|
||||||
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
model_type=ModelType.value_of(model_type), model=model, credential_id=credential_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -146,7 +146,7 @@ class PluginMigration:
|
||||||
futures.append(
|
futures.append(
|
||||||
thread_pool.submit(
|
thread_pool.submit(
|
||||||
process_tenant,
|
process_tenant,
|
||||||
current_app._get_current_object(), # type: ignore[attr-defined]
|
current_app._get_current_object(), # type: ignore
|
||||||
tenant_id,
|
tenant_id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -544,8 +544,8 @@ class BuiltinToolManageService:
|
||||||
try:
|
try:
|
||||||
# handle include, exclude
|
# handle include, exclude
|
||||||
if is_filtered(
|
if is_filtered(
|
||||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
|
||||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
|
||||||
data=provider_controller,
|
data=provider_controller,
|
||||||
name_func=lambda x: x.entity.identity.name,
|
name_func=lambda x: x.entity.identity.name,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -308,7 +308,7 @@ class MCPToolManageService:
|
||||||
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
provider_controller = MCPToolProviderController.from_db(mcp_provider)
|
||||||
tool_configuration = ProviderConfigEncrypter(
|
tool_configuration = ProviderConfigEncrypter(
|
||||||
tenant_id=mcp_provider.tenant_id,
|
tenant_id=mcp_provider.tenant_id,
|
||||||
config=list(provider_controller.get_credentials_schema()), # ty: ignore [invalid-argument-type]
|
config=list(provider_controller.get_credentials_schema()),
|
||||||
provider_config_cache=NoOpProviderCredentialCache(),
|
provider_config_cache=NoOpProviderCredentialCache(),
|
||||||
)
|
)
|
||||||
credentials = tool_configuration.encrypt(credentials)
|
credentials = tool_configuration.encrypt(credentials)
|
||||||
|
|
|
||||||
|
|
@ -102,7 +102,7 @@ def batch_create_segment_to_index_task(
|
||||||
for segment, tokens in zip(content, tokens_list):
|
for segment, tokens in zip(content, tokens_list):
|
||||||
content = segment["content"]
|
content = segment["content"]
|
||||||
doc_id = str(uuid.uuid4())
|
doc_id = str(uuid.uuid4())
|
||||||
segment_hash = helper.generate_text_hash(content) # type: ignore
|
segment_hash = helper.generate_text_hash(content)
|
||||||
max_position = (
|
max_position = (
|
||||||
db.session.query(func.max(DocumentSegment.position))
|
db.session.query(func.max(DocumentSegment.position))
|
||||||
.where(DocumentSegment.document_id == dataset_document.id)
|
.where(DocumentSegment.document_id == dataset_document.id)
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,11 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from pymochow import MochowClient # type: ignore
|
from pymochow import MochowClient
|
||||||
from pymochow.model.database import Database # type: ignore
|
from pymochow.model.database import Database
|
||||||
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState # type: ignore
|
from pymochow.model.enum import IndexState, IndexType, MetricType, ReadConsistency, TableState
|
||||||
from pymochow.model.schema import HNSWParams, VectorIndex # type: ignore
|
from pymochow.model.schema import HNSWParams, VectorIndex
|
||||||
from pymochow.model.table import Table # type: ignore
|
from pymochow.model.table import Table
|
||||||
|
|
||||||
|
|
||||||
class AttrDict(UserDict):
|
class AttrDict(UserDict):
|
||||||
|
|
|
||||||
|
|
@ -3,15 +3,15 @@ from typing import Any, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from tcvectordb import RPCVectorDBClient # type: ignore
|
from tcvectordb import RPCVectorDBClient
|
||||||
from tcvectordb.model import enum
|
from tcvectordb.model import enum
|
||||||
from tcvectordb.model.collection import FilterIndexConfig
|
from tcvectordb.model.collection import FilterIndexConfig
|
||||||
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank # type: ignore
|
from tcvectordb.model.document import AnnSearch, Document, Filter, KeywordSearch, Rerank
|
||||||
from tcvectordb.model.enum import ReadConsistency # type: ignore
|
from tcvectordb.model.enum import ReadConsistency
|
||||||
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex # type: ignore
|
from tcvectordb.model.index import FilterIndex, HNSWParams, Index, IndexField, VectorIndex
|
||||||
from tcvectordb.rpc.model.collection import RPCCollection
|
from tcvectordb.rpc.model.collection import RPCCollection
|
||||||
from tcvectordb.rpc.model.database import RPCDatabase
|
from tcvectordb.rpc.model.database import RPCDatabase
|
||||||
from xinference_client.types import Embedding # type: ignore
|
from xinference_client.types import Embedding
|
||||||
|
|
||||||
|
|
||||||
class MockTcvectordbClass:
|
class MockTcvectordbClass:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from volcengine.viking_db import ( # type: ignore
|
from volcengine.viking_db import (
|
||||||
Collection,
|
Collection,
|
||||||
Data,
|
Data,
|
||||||
DistanceType,
|
DistanceType,
|
||||||
|
|
|
||||||
|
|
@ -43,7 +43,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue:
|
||||||
"""Test with None input"""
|
"""Test with None input"""
|
||||||
# The method signature expects Union[dict, list, Segment], but implementation handles None
|
# The method signature expects Union[dict, list, Segment], but implementation handles None
|
||||||
# We'll test the actual behavior by passing an empty dict instead
|
# We'll test the actual behavior by passing an empty dict instead
|
||||||
result = WorkflowResponseConverter._fetch_files_from_variable_value(None) # type: ignore
|
result = WorkflowResponseConverter._fetch_files_from_variable_value(None)
|
||||||
assert result == []
|
assert result == []
|
||||||
|
|
||||||
def test_fetch_files_from_variable_value_with_empty_dict(self):
|
def test_fetch_files_from_variable_value_with_empty_dict(self):
|
||||||
|
|
|
||||||
|
|
@ -235,7 +235,7 @@ class TestIndividualHandlers:
|
||||||
# Type assertion needed due to union type
|
# Type assertion needed due to union type
|
||||||
text_content = result.content[0]
|
text_content = result.content[0]
|
||||||
assert hasattr(text_content, "text")
|
assert hasattr(text_content, "text")
|
||||||
assert text_content.text == "test answer" # type: ignore[attr-defined]
|
assert text_content.text == "test answer"
|
||||||
|
|
||||||
def test_handle_call_tool_no_end_user(self):
|
def test_handle_call_tool_no_end_user(self):
|
||||||
"""Test call tool handler without end user"""
|
"""Test call tool handler without end user"""
|
||||||
|
|
|
||||||
|
|
@ -212,7 +212,7 @@ class TestValidateResult:
|
||||||
parameters=[
|
parameters=[
|
||||||
ParameterConfig(
|
ParameterConfig(
|
||||||
name="status",
|
name="status",
|
||||||
type="select", # type: ignore
|
type="select",
|
||||||
description="Status",
|
description="Status",
|
||||||
required=True,
|
required=True,
|
||||||
options=["active", "inactive"],
|
options=["active", "inactive"],
|
||||||
|
|
@ -400,7 +400,7 @@ class TestTransformResult:
|
||||||
parameters=[
|
parameters=[
|
||||||
ParameterConfig(
|
ParameterConfig(
|
||||||
name="status",
|
name="status",
|
||||||
type="select", # type: ignore
|
type="select",
|
||||||
description="Status",
|
description="Status",
|
||||||
required=True,
|
required=True,
|
||||||
options=["active", "inactive"],
|
options=["active", "inactive"],
|
||||||
|
|
@ -414,7 +414,7 @@ class TestTransformResult:
|
||||||
parameters=[
|
parameters=[
|
||||||
ParameterConfig(
|
ParameterConfig(
|
||||||
name="status",
|
name="status",
|
||||||
type="select", # type: ignore
|
type="select",
|
||||||
description="Status",
|
description="Status",
|
||||||
required=True,
|
required=True,
|
||||||
options=["active", "inactive"],
|
options=["active", "inactive"],
|
||||||
|
|
|
||||||
|
|
@ -248,4 +248,4 @@ def test_constructor_with_extra_key():
|
||||||
# Test that SystemVariable should forbid extra keys
|
# Test that SystemVariable should forbid extra keys
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
# This should fail because there is an unexpected key.
|
# This should fail because there is an unexpected key.
|
||||||
SystemVariable(invalid_key=1) # type: ignore
|
SystemVariable(invalid_key=1)
|
||||||
|
|
|
||||||
|
|
@ -14,36 +14,36 @@ def _create_api_app():
|
||||||
api = ExternalApi(bp)
|
api = ExternalApi(bp)
|
||||||
|
|
||||||
@api.route("/bad-request")
|
@api.route("/bad-request")
|
||||||
class Bad(Resource): # type: ignore
|
class Bad(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
raise BadRequest("invalid input")
|
raise BadRequest("invalid input")
|
||||||
|
|
||||||
@api.route("/unauth")
|
@api.route("/unauth")
|
||||||
class Unauth(Resource): # type: ignore
|
class Unauth(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
raise Unauthorized("auth required")
|
raise Unauthorized("auth required")
|
||||||
|
|
||||||
@api.route("/value-error")
|
@api.route("/value-error")
|
||||||
class ValErr(Resource): # type: ignore
|
class ValErr(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
raise ValueError("boom")
|
raise ValueError("boom")
|
||||||
|
|
||||||
@api.route("/quota")
|
@api.route("/quota")
|
||||||
class Quota(Resource): # type: ignore
|
class Quota(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
raise AppInvokeQuotaExceededError("quota exceeded")
|
raise AppInvokeQuotaExceededError("quota exceeded")
|
||||||
|
|
||||||
@api.route("/general")
|
@api.route("/general")
|
||||||
class Gen(Resource): # type: ignore
|
class Gen(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
raise RuntimeError("oops")
|
raise RuntimeError("oops")
|
||||||
|
|
||||||
# Note: We avoid altering default_mediatype to keep normal error paths
|
# Note: We avoid altering default_mediatype to keep normal error paths
|
||||||
|
|
||||||
# Special 400 message rewrite
|
# Special 400 message rewrite
|
||||||
@api.route("/json-empty")
|
@api.route("/json-empty")
|
||||||
class JsonEmpty(Resource): # type: ignore
|
class JsonEmpty(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
e = BadRequest()
|
e = BadRequest()
|
||||||
# Force the specific message the handler rewrites
|
# Force the specific message the handler rewrites
|
||||||
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
|
e.description = "Failed to decode JSON object: Expecting value: line 1 column 1 (char 0)"
|
||||||
|
|
@ -51,11 +51,11 @@ def _create_api_app():
|
||||||
|
|
||||||
# 400 mapping payload path
|
# 400 mapping payload path
|
||||||
@api.route("/param-errors")
|
@api.route("/param-errors")
|
||||||
class ParamErrors(Resource): # type: ignore
|
class ParamErrors(Resource):
|
||||||
def get(self): # type: ignore
|
def get(self):
|
||||||
e = BadRequest()
|
e = BadRequest()
|
||||||
# Coerce a mapping description to trigger param error shaping
|
# Coerce a mapping description to trigger param error shaping
|
||||||
e.description = {"field": "is required"} # type: ignore[assignment]
|
e.description = {"field": "is required"}
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
app.register_blueprint(bp, url_prefix="/api")
|
app.register_blueprint(bp, url_prefix="/api")
|
||||||
|
|
@ -105,7 +105,7 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||||
|
|
||||||
orig_exc_info = ext.sys.exc_info
|
orig_exc_info = ext.sys.exc_info
|
||||||
try:
|
try:
|
||||||
ext.sys.exc_info = lambda: (None, None, None) # type: ignore[assignment]
|
ext.sys.exc_info = lambda: (None, None, None)
|
||||||
|
|
||||||
app = _create_api_app()
|
app = _create_api_app()
|
||||||
client = app.test_client()
|
client = app.test_client()
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ def test_current_user_not_accessible_across_threads(login_app: Flask, test_user:
|
||||||
# without preserve_flask_contexts
|
# without preserve_flask_contexts
|
||||||
result["user_accessible"] = current_user.is_authenticated
|
result["user_accessible"] = current_user.is_authenticated
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["error"] = str(e) # type: ignore
|
result["error"] = str(e)
|
||||||
|
|
||||||
# Run the function in a separate thread
|
# Run the function in a separate thread
|
||||||
thread = threading.Thread(target=check_user_in_thread)
|
thread = threading.Thread(target=check_user_in_thread)
|
||||||
|
|
@ -110,7 +110,7 @@ def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask,
|
||||||
else:
|
else:
|
||||||
result["user_accessible"] = False
|
result["user_accessible"] = False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result["error"] = str(e) # type: ignore
|
result["error"] = str(e)
|
||||||
|
|
||||||
# Run the function in a separate thread
|
# Run the function in a separate thread
|
||||||
thread = threading.Thread(target=check_user_in_thread_with_manager)
|
thread = threading.Thread(target=check_user_in_thread_with_manager)
|
||||||
|
|
|
||||||
|
|
@ -16,4 +16,4 @@ def test_oauth_base_methods_raise_not_implemented():
|
||||||
oauth.get_raw_user_info("token")
|
oauth.get_raw_user_info("token")
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
oauth._transform_user_info({}) # type: ignore[name-defined]
|
oauth._transform_user_info({})
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,8 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from qcloud_cos import CosS3Client # type: ignore
|
from qcloud_cos import CosS3Client
|
||||||
from qcloud_cos.streambody import StreamBody # type: ignore
|
from qcloud_cos.streambody import StreamBody
|
||||||
|
|
||||||
from tests.unit_tests.oss.__mock.base import (
|
from tests.unit_tests.oss.__mock.base import (
|
||||||
get_example_bucket,
|
get_example_bucket,
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from _pytest.monkeypatch import MonkeyPatch
|
from _pytest.monkeypatch import MonkeyPatch
|
||||||
from tos import TosClientV2 # type: ignore
|
from tos import TosClientV2
|
||||||
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput # type: ignore
|
from tos.clientv2 import DeleteObjectOutput, GetObjectOutput, HeadObjectOutput, PutObjectOutput
|
||||||
|
|
||||||
from tests.unit_tests.oss.__mock.base import (
|
from tests.unit_tests.oss.__mock.base import (
|
||||||
get_example_bucket,
|
get_example_bucket,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from qcloud_cos import CosConfig # type: ignore
|
from qcloud_cos import CosConfig
|
||||||
|
|
||||||
from extensions.storage.tencent_cos_storage import TencentCosStorage
|
from extensions.storage.tencent_cos_storage import TencentCosStorage
|
||||||
from tests.unit_tests.oss.__mock.base import (
|
from tests.unit_tests.oss.__mock.base import (
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from tos import TosClientV2 # type: ignore
|
from tos import TosClientV2
|
||||||
|
|
||||||
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
|
from extensions.storage.volcengine_tos_storage import VolcengineTosStorage
|
||||||
from tests.unit_tests.oss.__mock.base import (
|
from tests.unit_tests.oss.__mock.base import (
|
||||||
|
|
|
||||||
|
|
@ -125,13 +125,13 @@ class TestApiKeyAuthService:
|
||||||
mock_session.commit = Mock()
|
mock_session.commit = Mock()
|
||||||
|
|
||||||
args_copy = self.mock_args.copy()
|
args_copy = self.mock_args.copy()
|
||||||
original_key = args_copy["credentials"]["config"]["api_key"] # type: ignore
|
original_key = args_copy["credentials"]["config"]["api_key"]
|
||||||
|
|
||||||
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
|
ApiKeyAuthService.create_provider_auth(self.tenant_id, args_copy)
|
||||||
|
|
||||||
# Verify original key is replaced with encrypted key
|
# Verify original key is replaced with encrypted key
|
||||||
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key # type: ignore
|
assert args_copy["credentials"]["config"]["api_key"] == encrypted_key
|
||||||
assert args_copy["credentials"]["config"]["api_key"] != original_key # type: ignore
|
assert args_copy["credentials"]["config"]["api_key"] != original_key
|
||||||
|
|
||||||
# Verify encryption function is called correctly
|
# Verify encryption function is called correctly
|
||||||
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
|
mock_encrypter.encrypt_token.assert_called_once_with(self.tenant_id, original_key)
|
||||||
|
|
@ -268,7 +268,7 @@ class TestApiKeyAuthService:
|
||||||
def test_validate_api_key_auth_args_empty_credentials(self):
|
def test_validate_api_key_auth_args_empty_credentials(self):
|
||||||
"""Test API key auth args validation - empty credentials"""
|
"""Test API key auth args validation - empty credentials"""
|
||||||
args = self.mock_args.copy()
|
args = self.mock_args.copy()
|
||||||
args["credentials"] = None # type: ignore
|
args["credentials"] = None
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="credentials is required"):
|
with pytest.raises(ValueError, match="credentials is required"):
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
|
|
@ -284,7 +284,7 @@ class TestApiKeyAuthService:
|
||||||
def test_validate_api_key_auth_args_missing_auth_type(self):
|
def test_validate_api_key_auth_args_missing_auth_type(self):
|
||||||
"""Test API key auth args validation - missing auth_type"""
|
"""Test API key auth args validation - missing auth_type"""
|
||||||
args = self.mock_args.copy()
|
args = self.mock_args.copy()
|
||||||
del args["credentials"]["auth_type"] # type: ignore
|
del args["credentials"]["auth_type"]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="auth_type is required"):
|
with pytest.raises(ValueError, match="auth_type is required"):
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
|
|
@ -292,7 +292,7 @@ class TestApiKeyAuthService:
|
||||||
def test_validate_api_key_auth_args_empty_auth_type(self):
|
def test_validate_api_key_auth_args_empty_auth_type(self):
|
||||||
"""Test API key auth args validation - empty auth_type"""
|
"""Test API key auth args validation - empty auth_type"""
|
||||||
args = self.mock_args.copy()
|
args = self.mock_args.copy()
|
||||||
args["credentials"]["auth_type"] = "" # type: ignore
|
args["credentials"]["auth_type"] = ""
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="auth_type is required"):
|
with pytest.raises(ValueError, match="auth_type is required"):
|
||||||
ApiKeyAuthService.validate_api_key_auth_args(args)
|
ApiKeyAuthService.validate_api_key_auth_args(args)
|
||||||
|
|
@ -380,7 +380,7 @@ class TestApiKeyAuthService:
|
||||||
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
|
def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self):
|
||||||
"""Test API key auth args validation - dict credentials with list auth_type"""
|
"""Test API key auth args validation - dict credentials with list auth_type"""
|
||||||
args = self.mock_args.copy()
|
args = self.mock_args.copy()
|
||||||
args["credentials"]["auth_type"] = ["api_key"] # type: ignore # list instead of string
|
args["credentials"]["auth_type"] = ["api_key"]
|
||||||
|
|
||||||
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
|
# Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy
|
||||||
# So this should not raise exception, this test should pass
|
# So this should not raise exception, this test should pass
|
||||||
|
|
|
||||||
|
|
@ -116,10 +116,10 @@ class TestSystemOAuthEncrypter:
|
||||||
encrypter = SystemOAuthEncrypter("test_secret")
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
with pytest.raises(Exception): # noqa: B017
|
with pytest.raises(Exception): # noqa: B017
|
||||||
encrypter.encrypt_oauth_params(None) # type: ignore
|
encrypter.encrypt_oauth_params(None)
|
||||||
|
|
||||||
with pytest.raises(Exception): # noqa: B017
|
with pytest.raises(Exception): # noqa: B017
|
||||||
encrypter.encrypt_oauth_params("not_a_dict") # type: ignore
|
encrypter.encrypt_oauth_params("not_a_dict")
|
||||||
|
|
||||||
def test_decrypt_oauth_params_basic(self):
|
def test_decrypt_oauth_params_basic(self):
|
||||||
"""Test basic OAuth parameters decryption"""
|
"""Test basic OAuth parameters decryption"""
|
||||||
|
|
@ -207,12 +207,12 @@ class TestSystemOAuthEncrypter:
|
||||||
encrypter = SystemOAuthEncrypter("test_secret")
|
encrypter = SystemOAuthEncrypter("test_secret")
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
encrypter.decrypt_oauth_params(123) # type: ignore
|
encrypter.decrypt_oauth_params(123)
|
||||||
|
|
||||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||||
|
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
encrypter.decrypt_oauth_params(None) # type: ignore
|
encrypter.decrypt_oauth_params(None)
|
||||||
|
|
||||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
@ -461,14 +461,14 @@ class TestConvenienceFunctions:
|
||||||
"""Test convenience functions with error conditions"""
|
"""Test convenience functions with error conditions"""
|
||||||
# Test encryption with invalid input
|
# Test encryption with invalid input
|
||||||
with pytest.raises(Exception): # noqa: B017
|
with pytest.raises(Exception): # noqa: B017
|
||||||
encrypt_system_oauth_params(None) # type: ignore
|
encrypt_system_oauth_params(None)
|
||||||
|
|
||||||
# Test decryption with invalid input
|
# Test decryption with invalid input
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
decrypt_system_oauth_params("")
|
decrypt_system_oauth_params("")
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
decrypt_system_oauth_params(None) # type: ignore
|
decrypt_system_oauth_params(None)
|
||||||
|
|
||||||
|
|
||||||
class TestErrorHandling:
|
class TestErrorHandling:
|
||||||
|
|
@ -501,7 +501,7 @@ class TestErrorHandling:
|
||||||
|
|
||||||
# Test non-string error
|
# Test non-string error
|
||||||
with pytest.raises(ValueError) as exc_info:
|
with pytest.raises(ValueError) as exc_info:
|
||||||
encrypter.decrypt_oauth_params(123) # type: ignore
|
encrypter.decrypt_oauth_params(123)
|
||||||
assert "encrypted_data must be a string" in str(exc_info.value)
|
assert "encrypted_data must be a string" in str(exc_info.value)
|
||||||
|
|
||||||
# Test invalid format error
|
# Test invalid format error
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue