chore(api): Fix several typing errors (#37119)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
chariri 2026-06-06 10:44:32 +09:00 committed by GitHub
parent 1c0080be6f
commit 157ba6f5a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 88 additions and 70 deletions

View File

@ -2,7 +2,7 @@ import logging
import re
import uuid
from datetime import datetime
from typing import Any, Literal
from typing import Any, Literal, cast
from flask import request
from flask_restx import Resource
@ -65,14 +65,13 @@ register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
default="all", description="App mode filter"
)
mode: AppListMode = Field(default=cast(AppListMode, "all"), description="App mode filter")
name: str | None = Field(default=None, description="Filter by app name")
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
creator_ids: list[str] | None = Field(default=None, description="Filter by creator account IDs")

View File

@ -76,7 +76,7 @@ class EmailRegisterSendEmailApi(Resource):
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
language = "en-US"
if args.language in languages:
if args.language is not None and args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):

View File

@ -1,3 +1,3 @@
from fastopenapi.routers import FlaskRouter
from fastopenapi.routers.flask import FlaskRouter
console_router = FlaskRouter()

View File

@ -221,9 +221,9 @@ class TracingProviderConfigEntry(TypedDict):
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
def __getitem__(self, key: str) -> TracingProviderConfigEntry:
try:
match provider:
match key:
case TracingProviderEnum.LANGFUSE:
from dify_trace_langfuse.config import LangfuseConfig
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
@ -330,9 +330,9 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigE
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
raise KeyError(f"Unsupported tracing provider: {key}")
except ImportError:
raise ImportError(f"Provider {provider} is not installed.")
raise ImportError(f"Provider {key} is not installed.")
provider_config_map = OpsTraceProviderConfigMap()

View File

@ -412,7 +412,7 @@ class ToolManager:
tool_invoke_from=ToolInvokeFrom.AGENT,
credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
runtime_parameters: dict[str, Any] = {}
parameters = tool_entity.get_merged_runtime_parameters()
runtime_parameters = cls._convert_tool_parameters_type(
parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
@ -501,7 +501,7 @@ class ToolManager:
tool_invoke_from=ToolInvokeFrom.PLUGIN,
credential_id=credential_id,
)
runtime_parameters = {}
runtime_parameters: dict[str, Any] = {}
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM:
@ -1070,7 +1070,7 @@ class ToolManager:
from graphon.nodes.tool.entities import ToolNodeData
from graphon.nodes.tool.exc import ToolParameterError
runtime_parameters = {}
runtime_parameters: dict[str, Any] = {}
for parameter in parameters:
if (
parameter.type

View File

@ -63,7 +63,7 @@ def get_url(url: str, user_agent: str | None = None) -> str:
response = remote_fetcher.make_request("GET", url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
scraper = cloudscraper.create_scraper()
scraper.perform_request = remote_fetcher.make_request
object.__setattr__(scraper, "perform_request", remote_fetcher.make_request)
response = scraper.get(url, headers=headers, timeout=(120, 300))
if response.status_code != 200:

View File

@ -16080,9 +16080,17 @@ Shared permission levels for resources (datasets, credentials, etc.)
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| mode | string | *Enum:* `"automatic"`, `"custom"`, `"hierarchical"` | Yes |
| mode | [ProcessRuleMode](#processrulemode) | | Yes |
| rules | [Rule](#rule) | | No |
#### ProcessRuleMode
Dataset Process Rule Mode
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| ProcessRuleMode | string | Dataset Process Rule Mode | |
#### PublishWorkflowPayload
Payload for publishing snippet workflow.

View File

@ -3039,9 +3039,17 @@ Shared permission levels for resources (datasets, credentials, etc.)
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| mode | string | *Enum:* `"automatic"`, `"custom"`, `"hierarchical"` | Yes |
| mode | [ProcessRuleMode](#processrulemode) | | Yes |
| rules | [Rule](#rule) | | No |
#### ProcessRuleMode
Dataset Process Rule Mode
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| ProcessRuleMode | string | Dataset Process Rule Mode | |
#### RerankingModel
| Name | Type | Description | Required |

View File

@ -1,16 +1,8 @@
controllers/console/app/app.py
controllers/console/auth/email_register.py
controllers/console/init_validate.py
controllers/console/ping.py
controllers/console/setup.py
controllers/console/version.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/llm_generator/llm_generator.py
providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
core/ops/ops_trace_manager.py
core/prompt/utils/prompt_message_util.py
core/rag/retrieval/dataset_retrieval.py
core/tools/tool_manager.py
extensions/ext_celery.py
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py
@ -74,20 +66,9 @@ core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/tools/mcp_tool/provider.py
core/tools/plugin_tool/provider.py
core/tools/utils/web_reader_tool.py
core/tools/workflow_as_tool/provider.py
extensions/storage/huawei_obs_storage.py
libs/gmpy2_pkcs10aep_cipher.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/conversation_service.py
services/dataset_service.py
services/app_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
services/plugin/plugin_migration.py
services/trigger/trigger_provider_service.py
services/workflow_draft_variable_service.py
tasks/regenerate_summary_index_task.py
tasks/workflow_cfs_scheduler/cfs_scheduler.py

View File

@ -12,11 +12,11 @@ from libs.email_i18n import EmailType, get_email_i18n_service
redis_config = parse_url(dify_config.CELERY_BROKER_URL)
celery_redis = Redis(
host=redis_config.get("hostname") or "localhost",
port=redis_config.get("port") or 6379,
password=redis_config.get("password") or None,
host=str(redis_config.get("hostname") or "localhost"),
port=int(redis_config.get("port") or 6379),
password=str(pwd) if (pwd := redis_config.get("password")) is not None else None,
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
ssl=bool(dify_config.BROKER_USE_SSL),
ssl=dify_config.BROKER_USE_SSL,
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,

View File

@ -1414,7 +1414,7 @@ class TenantService:
.limit(1)
)
if ta:
tenant.role = ta.role
object.__setattr__(tenant, "role", ta.role)
else:
raise TenantNotFoundError("Tenant not found for the account.")
return tenant

View File

@ -551,9 +551,9 @@ class AppService:
keys = list(tool.keys())
if len(keys) >= 4:
# current tool standard
provider_type = tool.get("provider_type", "")
provider_id = tool.get("provider_id", "")
tool_name = tool.get("tool_name", "")
provider_type = str(tool.get("provider_type", ""))
provider_id = str(tool.get("provider_id", ""))
tool_name = str(tool.get("tool_name", ""))
if provider_type == "builtin":
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
elif provider_type == "api":

View File

@ -128,6 +128,8 @@ class ConversationService:
if auto_generate:
return cls.auto_generate_name(app_model, conversation)
else:
if name is None:
raise ValueError("name is required when auto_generate is false")
conversation.name = name
conversation.updated_at = naive_utc_now()
db.session.commit()

View File

@ -2173,7 +2173,7 @@ class DocumentService:
)
documents_map = {document.name: document for document in db_documents}
for file in files:
data_source_info: dict[str, str | bool] = {
data_source_info: dict[str, object] = {
"upload_file_id": file.id,
}
document = documents_map.get(file.name)
@ -2706,7 +2706,7 @@ class DocumentService:
# update document data source
if document_data.data_source:
file_name = ""
data_source_info: dict[str, str | bool] = {}
data_source_info: dict[str, object] = {}
if document_data.data_source.info_list.data_source_type == "upload_file":
if not document_data.data_source.info_list.file_info_list:
raise ValueError("No file info list found.")

View File

@ -6,6 +6,7 @@ from core.rag.entities import Rule
from core.rag.entities.metadata_entities import MetadataFilteringCondition
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from models.enums import ProcessRuleMode
class RerankingModel(BaseModel):
@ -55,7 +56,7 @@ class DataSource(BaseModel):
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
mode: ProcessRuleMode
rules: Rule | None = None

View File

@ -90,7 +90,6 @@ class PluginMigration:
# Use lock when updating counter
with counter_lock:
nonlocal handled_tenant_count
handled_tenant_count += 1
click.echo(
click.style(

View File

@ -467,7 +467,7 @@ class TriggerProviderService:
if not subscription:
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
if subscription.credential_type != CredentialType.OAUTH2.value:
if subscription.credential_type != CredentialType.OAUTH2:
raise ValueError("Only OAuth credentials can be refreshed")
provider_id = TriggerProviderID(subscription.provider_id)

View File

@ -13,12 +13,14 @@ class SchedulerCommand(StrEnum):
NONE = "none"
class CFSPlanScheduler(ABC):
class CFSPlanScheduler[PlanT: WorkflowScheduleCFSPlanEntity](ABC):
"""
CFS plan scheduler.
"""
def __init__(self, plan: WorkflowScheduleCFSPlanEntity):
plan: PlanT
def __init__(self, plan: PlanT):
"""
Initialize the CFS plan scheduler.

View File

@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence, Set
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import StrEnum
from typing import Any, ClassVar, NotRequired, TypedDict
from typing import Any, ClassVar, NotRequired, TypedDict, cast
from sqlalchemy import Engine, delete, orm, select
from sqlalchemy.dialects.mysql import insert as mysql_insert
@ -1068,9 +1068,10 @@ class DraftVariableSaver:
original_length = len(value_seg.value)
# Prepare content for storage
original_content_serialized: str
if isinstance(value_seg, StringSegment):
# For string types, store as plain text
original_content_serialized = value_seg.value
original_content_serialized = cast(str, value_seg.value)
content_type = "text/plain"
filename = f"{self._generate_filename(name)}.txt"
else:

View File

@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from models.enums import SummaryStatus
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@ -102,8 +103,8 @@ def regenerate_summary_index_task(
DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content
# Include completed summaries or error summaries (with content)
or_(
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.status == "error",
DocumentSegmentSummary.status == SummaryStatus.COMPLETED,
DocumentSegmentSummary.status == SummaryStatus.ERROR,
),
DatasetDocument.enabled == True, # Document must be enabled
DatasetDocument.archived == False, # Document must not be archived
@ -174,7 +175,7 @@ def regenerate_summary_index_task(
)
total_segments_failed += 1
# Update summary record with error status
summary_record.status = "error"
summary_record.status = SummaryStatus.ERROR
summary_record.error = f"Re-vectorization failed: {str(e)}"
session.add(summary_record)
session.commit()
@ -240,10 +241,10 @@ def regenerate_summary_index_task(
)
for segment in segments:
summary_record = None
existing_summary_record: DocumentSegmentSummary | None = None
try:
# Get existing summary record
summary_record = session.scalar(
existing_summary_record = session.scalar(
select(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
@ -252,7 +253,7 @@ def regenerate_summary_index_task(
.limit(1)
)
if not summary_record:
if not existing_summary_record:
logger.warning("Summary record not found for segment %s, skipping", segment.id)
continue
@ -272,10 +273,10 @@ def regenerate_summary_index_task(
)
total_segments_failed += 1
# Update summary record with error status
if summary_record:
summary_record.status = "error"
summary_record.error = f"Regeneration failed: {str(e)}"
session.add(summary_record)
if existing_summary_record is not None:
existing_summary_record.status = SummaryStatus.ERROR
existing_summary_record.error = f"Regeneration failed: {str(e)}"
session.add(existing_summary_record)
session.commit()
continue

View File

@ -11,13 +11,11 @@ class AsyncWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity):
queue: AsyncWorkflowQueue
class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler):
class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler[AsyncWorkflowCFSPlanEntity]):
"""
Trigger workflow CFS plan scheduler.
"""
plan: AsyncWorkflowCFSPlanEntity
def can_schedule(self) -> SchedulerCommand:
"""
Check if the workflow can be scheduled.

View File

@ -37,7 +37,7 @@ class TestCFSPlanScheduler:
granularity=10,
)
with pytest.raises(TypeError):
CFSPlanScheduler(plan)
CFSPlanScheduler(plan) # type: ignore
def test_concrete_subclass_can_schedule(self):
plan = WorkflowScheduleCFSPlanEntity(

View File

@ -664,7 +664,7 @@ export type DataSource = {
}
export type ProcessRule = {
mode: 'automatic' | 'custom' | 'hierarchical'
mode: ProcessRuleMode
rules?: Rule
}
@ -866,6 +866,8 @@ export type InfoList = {
website_info_list?: WebsiteInfo
}
export type ProcessRuleMode = 'automatic' | 'custom' | 'hierarchical'
export type Rule = {
parent_mode?: 'full-doc' | 'paragraph' | null
pre_processing_rules?: Array<PreProcessingRule> | null

View File

@ -709,6 +709,13 @@ export const zDatasetRerankingModel = z.object({
reranking_provider_name: z.string().optional(),
})
/**
* ProcessRuleMode
*
* Dataset Process Rule Mode
*/
export const zProcessRuleMode = z.enum(['automatic', 'custom', 'hierarchical'])
/**
* RerankingModel
*/
@ -1128,7 +1135,7 @@ export const zRule = z.object({
* ProcessRule
*/
export const zProcessRule = z.object({
mode: z.enum(['automatic', 'custom', 'hierarchical']),
mode: zProcessRuleMode,
rules: zRule.optional(),
})

View File

@ -721,10 +721,12 @@ export type PreProcessingRule = {
}
export type ProcessRule = {
mode: 'automatic' | 'custom' | 'hierarchical'
mode: ProcessRuleMode
rules?: Rule
}
export type ProcessRuleMode = 'automatic' | 'custom' | 'hierarchical'
export type RerankingModel = {
reranking_model_name?: string | null
reranking_provider_name?: string | null

View File

@ -894,6 +894,13 @@ export const zPreProcessingRule = z.object({
id: z.string(),
})
/**
* ProcessRuleMode
*
* Dataset Process Rule Mode
*/
export const zProcessRuleMode = z.enum(['automatic', 'custom', 'hierarchical'])
/**
* RerankingModel
*/
@ -1062,7 +1069,7 @@ export const zRule = z.object({
* ProcessRule
*/
export const zProcessRule = z.object({
mode: z.enum(['automatic', 'custom', 'hierarchical']),
mode: zProcessRuleMode,
rules: zRule.optional(),
})