mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:13:59 +08:00
chore(api): Fix several typing errors (#37119)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
1c0080be6f
commit
157ba6f5a0
@ -2,7 +2,7 @@ import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
@ -65,14 +65,13 @@ register_enum_models(console_ns, IconType)
|
||||
_logger = logging.getLogger(__name__)
|
||||
_TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$")
|
||||
_CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$")
|
||||
AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"]
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
mode: AppListMode = Field(default=cast(AppListMode, "all"), description="App mode filter")
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs")
|
||||
creator_ids: list[str] | None = Field(default=None, description="Filter by creator account IDs")
|
||||
|
||||
@ -76,7 +76,7 @@ class EmailRegisterSendEmailApi(Resource):
|
||||
if AccountService.is_email_send_ip_limit(ip_address):
|
||||
raise EmailSendIpLimitError()
|
||||
language = "en-US"
|
||||
if args.language in languages:
|
||||
if args.language is not None and args.language in languages:
|
||||
language = args.language
|
||||
|
||||
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from fastopenapi.routers.flask import FlaskRouter
|
||||
|
||||
console_router = FlaskRouter()
|
||||
|
||||
@ -221,9 +221,9 @@ class TracingProviderConfigEntry(TypedDict):
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]):
|
||||
def __getitem__(self, provider: str) -> TracingProviderConfigEntry:
|
||||
def __getitem__(self, key: str) -> TracingProviderConfigEntry:
|
||||
try:
|
||||
match provider:
|
||||
match key:
|
||||
case TracingProviderEnum.LANGFUSE:
|
||||
from dify_trace_langfuse.config import LangfuseConfig
|
||||
from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace
|
||||
@ -330,9 +330,9 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigE
|
||||
}
|
||||
|
||||
case _:
|
||||
raise KeyError(f"Unsupported tracing provider: {provider}")
|
||||
raise KeyError(f"Unsupported tracing provider: {key}")
|
||||
except ImportError:
|
||||
raise ImportError(f"Provider {provider} is not installed.")
|
||||
raise ImportError(f"Provider {key} is not installed.")
|
||||
|
||||
|
||||
provider_config_map = OpsTraceProviderConfigMap()
|
||||
|
||||
@ -412,7 +412,7 @@ class ToolManager:
|
||||
tool_invoke_from=ToolInvokeFrom.AGENT,
|
||||
credential_id=agent_tool.credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
runtime_parameters: dict[str, Any] = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
runtime_parameters = cls._convert_tool_parameters_type(
|
||||
parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
|
||||
@ -501,7 +501,7 @@ class ToolManager:
|
||||
tool_invoke_from=ToolInvokeFrom.PLUGIN,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
runtime_parameters = {}
|
||||
runtime_parameters: dict[str, Any] = {}
|
||||
parameters = tool_entity.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
@ -1070,7 +1070,7 @@ class ToolManager:
|
||||
from graphon.nodes.tool.entities import ToolNodeData
|
||||
from graphon.nodes.tool.exc import ToolParameterError
|
||||
|
||||
runtime_parameters = {}
|
||||
runtime_parameters: dict[str, Any] = {}
|
||||
for parameter in parameters:
|
||||
if (
|
||||
parameter.type
|
||||
|
||||
@ -63,7 +63,7 @@ def get_url(url: str, user_agent: str | None = None) -> str:
|
||||
response = remote_fetcher.make_request("GET", url, headers=headers, follow_redirects=True, timeout=(120, 300))
|
||||
elif response.status_code == 403:
|
||||
scraper = cloudscraper.create_scraper()
|
||||
scraper.perform_request = remote_fetcher.make_request
|
||||
object.__setattr__(scraper, "perform_request", remote_fetcher.make_request)
|
||||
response = scraper.get(url, headers=headers, timeout=(120, 300))
|
||||
|
||||
if response.status_code != 200:
|
||||
|
||||
@ -16080,9 +16080,17 @@ Shared permission levels for resources (datasets, credentials, etc.)
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| mode | string | *Enum:* `"automatic"`, `"custom"`, `"hierarchical"` | Yes |
|
||||
| mode | [ProcessRuleMode](#processrulemode) | | Yes |
|
||||
| rules | [Rule](#rule) | | No |
|
||||
|
||||
#### ProcessRuleMode
|
||||
|
||||
Dataset Process Rule Mode
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| ProcessRuleMode | string | Dataset Process Rule Mode | |
|
||||
|
||||
#### PublishWorkflowPayload
|
||||
|
||||
Payload for publishing snippet workflow.
|
||||
|
||||
@ -3039,9 +3039,17 @@ Shared permission levels for resources (datasets, credentials, etc.)
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| mode | string | *Enum:* `"automatic"`, `"custom"`, `"hierarchical"` | Yes |
|
||||
| mode | [ProcessRuleMode](#processrulemode) | | Yes |
|
||||
| rules | [Rule](#rule) | | No |
|
||||
|
||||
#### ProcessRuleMode
|
||||
|
||||
Dataset Process Rule Mode
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| ProcessRuleMode | string | Dataset Process Rule Mode | |
|
||||
|
||||
#### RerankingModel
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
|
||||
@ -1,16 +1,8 @@
|
||||
controllers/console/app/app.py
|
||||
controllers/console/auth/email_register.py
|
||||
controllers/console/init_validate.py
|
||||
controllers/console/ping.py
|
||||
controllers/console/setup.py
|
||||
controllers/console/version.py
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/llm_generator/llm_generator.py
|
||||
providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py
|
||||
core/ops/ops_trace_manager.py
|
||||
core/prompt/utils/prompt_message_util.py
|
||||
core/rag/retrieval/dataset_retrieval.py
|
||||
core/tools/tool_manager.py
|
||||
extensions/ext_celery.py
|
||||
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py
|
||||
providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py
|
||||
@ -74,20 +66,9 @@ core/rag/index_processor/processor/parent_child_index_processor.py
|
||||
core/rag/index_processor/processor/qa_index_processor.py
|
||||
core/tools/mcp_tool/provider.py
|
||||
core/tools/plugin_tool/provider.py
|
||||
core/tools/utils/web_reader_tool.py
|
||||
core/tools/workflow_as_tool/provider.py
|
||||
extensions/storage/huawei_obs_storage.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
services/conversation_service.py
|
||||
services/dataset_service.py
|
||||
services/app_service.py
|
||||
services/document_indexing_proxy/document_indexing_task_proxy.py
|
||||
services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py
|
||||
services/plugin/plugin_migration.py
|
||||
services/trigger/trigger_provider_service.py
|
||||
services/workflow_draft_variable_service.py
|
||||
tasks/regenerate_summary_index_task.py
|
||||
tasks/workflow_cfs_scheduler/cfs_scheduler.py
|
||||
|
||||
@ -12,11 +12,11 @@ from libs.email_i18n import EmailType, get_email_i18n_service
|
||||
|
||||
redis_config = parse_url(dify_config.CELERY_BROKER_URL)
|
||||
celery_redis = Redis(
|
||||
host=redis_config.get("hostname") or "localhost",
|
||||
port=redis_config.get("port") or 6379,
|
||||
password=redis_config.get("password") or None,
|
||||
host=str(redis_config.get("hostname") or "localhost"),
|
||||
port=int(redis_config.get("port") or 6379),
|
||||
password=str(pwd) if (pwd := redis_config.get("password")) is not None else None,
|
||||
db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1,
|
||||
ssl=bool(dify_config.BROKER_USE_SSL),
|
||||
ssl=dify_config.BROKER_USE_SSL,
|
||||
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None,
|
||||
ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None,
|
||||
|
||||
@ -1414,7 +1414,7 @@ class TenantService:
|
||||
.limit(1)
|
||||
)
|
||||
if ta:
|
||||
tenant.role = ta.role
|
||||
object.__setattr__(tenant, "role", ta.role)
|
||||
else:
|
||||
raise TenantNotFoundError("Tenant not found for the account.")
|
||||
return tenant
|
||||
|
||||
@ -551,9 +551,9 @@ class AppService:
|
||||
keys = list(tool.keys())
|
||||
if len(keys) >= 4:
|
||||
# current tool standard
|
||||
provider_type = tool.get("provider_type", "")
|
||||
provider_id = tool.get("provider_id", "")
|
||||
tool_name = tool.get("tool_name", "")
|
||||
provider_type = str(tool.get("provider_type", ""))
|
||||
provider_id = str(tool.get("provider_id", ""))
|
||||
tool_name = str(tool.get("tool_name", ""))
|
||||
if provider_type == "builtin":
|
||||
meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon"
|
||||
elif provider_type == "api":
|
||||
|
||||
@ -128,6 +128,8 @@ class ConversationService:
|
||||
if auto_generate:
|
||||
return cls.auto_generate_name(app_model, conversation)
|
||||
else:
|
||||
if name is None:
|
||||
raise ValueError("name is required when auto_generate is false")
|
||||
conversation.name = name
|
||||
conversation.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
|
||||
@ -2173,7 +2173,7 @@ class DocumentService:
|
||||
)
|
||||
documents_map = {document.name: document for document in db_documents}
|
||||
for file in files:
|
||||
data_source_info: dict[str, str | bool] = {
|
||||
data_source_info: dict[str, object] = {
|
||||
"upload_file_id": file.id,
|
||||
}
|
||||
document = documents_map.get(file.name)
|
||||
@ -2706,7 +2706,7 @@ class DocumentService:
|
||||
# update document data source
|
||||
if document_data.data_source:
|
||||
file_name = ""
|
||||
data_source_info: dict[str, str | bool] = {}
|
||||
data_source_info: dict[str, object] = {}
|
||||
if document_data.data_source.info_list.data_source_type == "upload_file":
|
||||
if not document_data.data_source.info_list.file_info_list:
|
||||
raise ValueError("No file info list found.")
|
||||
|
||||
@ -6,6 +6,7 @@ from core.rag.entities import Rule
|
||||
from core.rag.entities.metadata_entities import MetadataFilteringCondition
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from models.enums import ProcessRuleMode
|
||||
|
||||
|
||||
class RerankingModel(BaseModel):
|
||||
@ -55,7 +56,7 @@ class DataSource(BaseModel):
|
||||
|
||||
|
||||
class ProcessRule(BaseModel):
|
||||
mode: Literal["automatic", "custom", "hierarchical"]
|
||||
mode: ProcessRuleMode
|
||||
rules: Rule | None = None
|
||||
|
||||
|
||||
|
||||
@ -90,7 +90,6 @@ class PluginMigration:
|
||||
|
||||
# Use lock when updating counter
|
||||
with counter_lock:
|
||||
nonlocal handled_tenant_count
|
||||
handled_tenant_count += 1
|
||||
click.echo(
|
||||
click.style(
|
||||
|
||||
@ -467,7 +467,7 @@ class TriggerProviderService:
|
||||
if not subscription:
|
||||
raise ValueError(f"Trigger provider subscription {subscription_id} not found")
|
||||
|
||||
if subscription.credential_type != CredentialType.OAUTH2.value:
|
||||
if subscription.credential_type != CredentialType.OAUTH2:
|
||||
raise ValueError("Only OAuth credentials can be refreshed")
|
||||
|
||||
provider_id = TriggerProviderID(subscription.provider_id)
|
||||
|
||||
@ -13,12 +13,14 @@ class SchedulerCommand(StrEnum):
|
||||
NONE = "none"
|
||||
|
||||
|
||||
class CFSPlanScheduler(ABC):
|
||||
class CFSPlanScheduler[PlanT: WorkflowScheduleCFSPlanEntity](ABC):
|
||||
"""
|
||||
CFS plan scheduler.
|
||||
"""
|
||||
|
||||
def __init__(self, plan: WorkflowScheduleCFSPlanEntity):
|
||||
plan: PlanT
|
||||
|
||||
def __init__(self, plan: PlanT):
|
||||
"""
|
||||
Initialize the CFS plan scheduler.
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence, Set
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Any, ClassVar, NotRequired, TypedDict
|
||||
from typing import Any, ClassVar, NotRequired, TypedDict, cast
|
||||
|
||||
from sqlalchemy import Engine, delete, orm, select
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
@ -1068,9 +1068,10 @@ class DraftVariableSaver:
|
||||
original_length = len(value_seg.value)
|
||||
|
||||
# Prepare content for storage
|
||||
original_content_serialized: str
|
||||
if isinstance(value_seg, StringSegment):
|
||||
# For string types, store as plain text
|
||||
original_content_serialized = value_seg.value
|
||||
original_content_serialized = cast(str, value_seg.value)
|
||||
content_type = "text/plain"
|
||||
filename = f"{self._generate_filename(name)}.txt"
|
||||
else:
|
||||
|
||||
@ -12,6 +12,7 @@ from core.db.session_factory import session_factory
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import SummaryStatus
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -102,8 +103,8 @@ def regenerate_summary_index_task(
|
||||
DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content
|
||||
# Include completed summaries or error summaries (with content)
|
||||
or_(
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
DocumentSegmentSummary.status == "error",
|
||||
DocumentSegmentSummary.status == SummaryStatus.COMPLETED,
|
||||
DocumentSegmentSummary.status == SummaryStatus.ERROR,
|
||||
),
|
||||
DatasetDocument.enabled == True, # Document must be enabled
|
||||
DatasetDocument.archived == False, # Document must not be archived
|
||||
@ -174,7 +175,7 @@ def regenerate_summary_index_task(
|
||||
)
|
||||
total_segments_failed += 1
|
||||
# Update summary record with error status
|
||||
summary_record.status = "error"
|
||||
summary_record.status = SummaryStatus.ERROR
|
||||
summary_record.error = f"Re-vectorization failed: {str(e)}"
|
||||
session.add(summary_record)
|
||||
session.commit()
|
||||
@ -240,10 +241,10 @@ def regenerate_summary_index_task(
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
summary_record = None
|
||||
existing_summary_record: DocumentSegmentSummary | None = None
|
||||
try:
|
||||
# Get existing summary record
|
||||
summary_record = session.scalar(
|
||||
existing_summary_record = session.scalar(
|
||||
select(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
@ -252,7 +253,7 @@ def regenerate_summary_index_task(
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if not summary_record:
|
||||
if not existing_summary_record:
|
||||
logger.warning("Summary record not found for segment %s, skipping", segment.id)
|
||||
continue
|
||||
|
||||
@ -272,10 +273,10 @@ def regenerate_summary_index_task(
|
||||
)
|
||||
total_segments_failed += 1
|
||||
# Update summary record with error status
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.error = f"Regeneration failed: {str(e)}"
|
||||
session.add(summary_record)
|
||||
if existing_summary_record is not None:
|
||||
existing_summary_record.status = SummaryStatus.ERROR
|
||||
existing_summary_record.error = f"Regeneration failed: {str(e)}"
|
||||
session.add(existing_summary_record)
|
||||
session.commit()
|
||||
continue
|
||||
|
||||
|
||||
@ -11,13 +11,11 @@ class AsyncWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity):
|
||||
queue: AsyncWorkflowQueue
|
||||
|
||||
|
||||
class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler):
|
||||
class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler[AsyncWorkflowCFSPlanEntity]):
|
||||
"""
|
||||
Trigger workflow CFS plan scheduler.
|
||||
"""
|
||||
|
||||
plan: AsyncWorkflowCFSPlanEntity
|
||||
|
||||
def can_schedule(self) -> SchedulerCommand:
|
||||
"""
|
||||
Check if the workflow can be scheduled.
|
||||
|
||||
@ -37,7 +37,7 @@ class TestCFSPlanScheduler:
|
||||
granularity=10,
|
||||
)
|
||||
with pytest.raises(TypeError):
|
||||
CFSPlanScheduler(plan)
|
||||
CFSPlanScheduler(plan) # type: ignore
|
||||
|
||||
def test_concrete_subclass_can_schedule(self):
|
||||
plan = WorkflowScheduleCFSPlanEntity(
|
||||
|
||||
@ -664,7 +664,7 @@ export type DataSource = {
|
||||
}
|
||||
|
||||
export type ProcessRule = {
|
||||
mode: 'automatic' | 'custom' | 'hierarchical'
|
||||
mode: ProcessRuleMode
|
||||
rules?: Rule
|
||||
}
|
||||
|
||||
@ -866,6 +866,8 @@ export type InfoList = {
|
||||
website_info_list?: WebsiteInfo
|
||||
}
|
||||
|
||||
export type ProcessRuleMode = 'automatic' | 'custom' | 'hierarchical'
|
||||
|
||||
export type Rule = {
|
||||
parent_mode?: 'full-doc' | 'paragraph' | null
|
||||
pre_processing_rules?: Array<PreProcessingRule> | null
|
||||
|
||||
@ -709,6 +709,13 @@ export const zDatasetRerankingModel = z.object({
|
||||
reranking_provider_name: z.string().optional(),
|
||||
})
|
||||
|
||||
/**
|
||||
* ProcessRuleMode
|
||||
*
|
||||
* Dataset Process Rule Mode
|
||||
*/
|
||||
export const zProcessRuleMode = z.enum(['automatic', 'custom', 'hierarchical'])
|
||||
|
||||
/**
|
||||
* RerankingModel
|
||||
*/
|
||||
@ -1128,7 +1135,7 @@ export const zRule = z.object({
|
||||
* ProcessRule
|
||||
*/
|
||||
export const zProcessRule = z.object({
|
||||
mode: z.enum(['automatic', 'custom', 'hierarchical']),
|
||||
mode: zProcessRuleMode,
|
||||
rules: zRule.optional(),
|
||||
})
|
||||
|
||||
|
||||
@ -721,10 +721,12 @@ export type PreProcessingRule = {
|
||||
}
|
||||
|
||||
export type ProcessRule = {
|
||||
mode: 'automatic' | 'custom' | 'hierarchical'
|
||||
mode: ProcessRuleMode
|
||||
rules?: Rule
|
||||
}
|
||||
|
||||
export type ProcessRuleMode = 'automatic' | 'custom' | 'hierarchical'
|
||||
|
||||
export type RerankingModel = {
|
||||
reranking_model_name?: string | null
|
||||
reranking_provider_name?: string | null
|
||||
|
||||
@ -894,6 +894,13 @@ export const zPreProcessingRule = z.object({
|
||||
id: z.string(),
|
||||
})
|
||||
|
||||
/**
|
||||
* ProcessRuleMode
|
||||
*
|
||||
* Dataset Process Rule Mode
|
||||
*/
|
||||
export const zProcessRuleMode = z.enum(['automatic', 'custom', 'hierarchical'])
|
||||
|
||||
/**
|
||||
* RerankingModel
|
||||
*/
|
||||
@ -1062,7 +1069,7 @@ export const zRule = z.object({
|
||||
* ProcessRule
|
||||
*/
|
||||
export const zProcessRule = z.object({
|
||||
mode: z.enum(['automatic', 'custom', 'hierarchical']),
|
||||
mode: zProcessRuleMode,
|
||||
rules: zRule.optional(),
|
||||
})
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user