diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 66bf77402f..818da5553d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -2,7 +2,7 @@ import logging import re import uuid from datetime import datetime -from typing import Any, Literal +from typing import Any, Literal, cast from flask import request from flask_restx import Resource @@ -65,14 +65,13 @@ register_enum_models(console_ns, IconType) _logger = logging.getLogger(__name__) _TAG_IDS_BRACKET_PATTERN = re.compile(r"^tag_ids\[(\d+)\]$") _CREATOR_IDS_BRACKET_PATTERN = re.compile(r"^creator_ids\[(\d+)\]$") +AppListMode = Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] class AppListQuery(BaseModel): page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)") - mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field( - default="all", description="App mode filter" - ) + mode: AppListMode = Field(default=cast(AppListMode, "all"), description="App mode filter") name: str | None = Field(default=None, description="Filter by app name") tag_ids: list[str] | None = Field(default=None, description="Filter by tag IDs") creator_ids: list[str] | None = Field(default=None, description="Filter by creator account IDs") diff --git a/api/controllers/console/auth/email_register.py b/api/controllers/console/auth/email_register.py index 15d8b2ccd2..c22b14159c 100644 --- a/api/controllers/console/auth/email_register.py +++ b/api/controllers/console/auth/email_register.py @@ -76,7 +76,7 @@ class EmailRegisterSendEmailApi(Resource): if AccountService.is_email_send_ip_limit(ip_address): raise EmailSendIpLimitError() language = "en-US" - if args.language in languages: + if args.language is not None and args.language in languages: language = args.language if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email): diff --git a/api/controllers/fastopenapi.py b/api/controllers/fastopenapi.py index c13f22338b..260e7a83e2 100644 --- a/api/controllers/fastopenapi.py +++ b/api/controllers/fastopenapi.py @@ -1,3 +1,3 @@ -from fastopenapi.routers import FlaskRouter +from fastopenapi.routers.flask import FlaskRouter console_router = FlaskRouter() diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 6afd22a6cc..c9bee61541 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -221,9 +221,9 @@ class TracingProviderConfigEntry(TypedDict): class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigEntry]): - def __getitem__(self, provider: str) -> TracingProviderConfigEntry: + def __getitem__(self, key: str) -> TracingProviderConfigEntry: try: - match provider: + match key: case TracingProviderEnum.LANGFUSE: from dify_trace_langfuse.config import LangfuseConfig from dify_trace_langfuse.langfuse_trace import LangFuseDataTrace @@ -330,9 +330,9 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, TracingProviderConfigE } case _: - raise KeyError(f"Unsupported tracing provider: {provider}") + raise KeyError(f"Unsupported tracing provider: {key}") except ImportError: - raise ImportError(f"Provider {provider} is not installed.") + raise ImportError(f"Provider {key} is not installed.") provider_config_map = OpsTraceProviderConfigMap() diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 4b7e5932ba..99be960a20 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -412,7 +412,7 @@ class ToolManager: tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, ) - runtime_parameters = {} + runtime_parameters: dict[str, Any] = {} parameters = tool_entity.get_merged_runtime_parameters() runtime_parameters = cls._convert_tool_parameters_type( parameters, variable_pool, agent_tool.tool_parameters, typ="agent" @@ -501,7 +501,7 @@ class ToolManager: tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, ) - runtime_parameters = {} + runtime_parameters: dict[str, Any] = {} parameters = tool_entity.get_merged_runtime_parameters() for parameter in parameters: if parameter.form == ToolParameter.ToolParameterForm.FORM: @@ -1070,7 +1070,7 @@ class ToolManager: from graphon.nodes.tool.entities import ToolNodeData from graphon.nodes.tool.exc import ToolParameterError - runtime_parameters = {} + runtime_parameters: dict[str, Any] = {} for parameter in parameters: if ( parameter.type diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 677aa4a533..c156fd888f 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -63,7 +63,7 @@ def get_url(url: str, user_agent: str | None = None) -> str: response = remote_fetcher.make_request("GET", url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - scraper.perform_request = remote_fetcher.make_request + object.__setattr__(scraper, "perform_request", remote_fetcher.make_request) response = scraper.get(url, headers=headers, timeout=(120, 300)) if response.status_code != 200: diff --git a/api/openapi/markdown/console-swagger.md b/api/openapi/markdown/console-swagger.md index b3c199d558..b8f89bc24d 100644 --- a/api/openapi/markdown/console-swagger.md +++ b/api/openapi/markdown/console-swagger.md @@ -16080,9 +16080,17 @@ Shared permission levels for resources (datasets, credentials, etc.) | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| mode | string | *Enum:* `"automatic"`, `"custom"`, `"hierarchical"` | Yes | +| mode | [ProcessRuleMode](#processrulemode) | | Yes | | rules | [Rule](#rule) | | No | +#### ProcessRuleMode + +Dataset Process Rule Mode + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| ProcessRuleMode | string | Dataset Process Rule Mode | | + #### PublishWorkflowPayload Payload for publishing snippet workflow. diff --git a/api/openapi/markdown/service-swagger.md b/api/openapi/markdown/service-swagger.md index 735d610eff..687a47e012 100644 --- a/api/openapi/markdown/service-swagger.md +++ b/api/openapi/markdown/service-swagger.md @@ -3039,9 +3039,17 @@ Shared permission levels for resources (datasets, credentials, etc.) | Name | Type | Description | Required | | ---- | ---- | ----------- | -------- | -| mode | string | *Enum:* `"automatic"`, `"custom"`, `"hierarchical"` | Yes | +| mode | [ProcessRuleMode](#processrulemode) | | Yes | | rules | [Rule](#rule) | | No | +#### ProcessRuleMode + +Dataset Process Rule Mode + +| Name | Type | Description | Required | +| ---- | ---- | ----------- | -------- | +| ProcessRuleMode | string | Dataset Process Rule Mode | | + #### RerankingModel | Name | Type | Description | Required | diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index 2b7baea931..c97014ebcf 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -1,16 +1,8 @@ -controllers/console/app/app.py -controllers/console/auth/email_register.py -controllers/console/init_validate.py -controllers/console/ping.py -controllers/console/setup.py -controllers/console/version.py core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py core/llm_generator/llm_generator.py providers/trace/trace-mlflow/src/dify_trace_mlflow/mlflow_trace.py -core/ops/ops_trace_manager.py core/prompt/utils/prompt_message_util.py core/rag/retrieval/dataset_retrieval.py -core/tools/tool_manager.py extensions/ext_celery.py providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_factory.py providers/vdb/vdb-alibabacloud-mysql/tests/unit_tests/test_alibabacloud_mysql_vector.py @@ -74,20 +66,9 @@ core/rag/index_processor/processor/parent_child_index_processor.py core/rag/index_processor/processor/qa_index_processor.py core/tools/mcp_tool/provider.py core/tools/plugin_tool/provider.py -core/tools/utils/web_reader_tool.py core/tools/workflow_as_tool/provider.py extensions/storage/huawei_obs_storage.py libs/gmpy2_pkcs10aep_cipher.py -schedule/queue_monitor_task.py -services/account_service.py services/audio_service.py -services/conversation_service.py -services/dataset_service.py -services/app_service.py services/document_indexing_proxy/document_indexing_task_proxy.py services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py -services/plugin/plugin_migration.py -services/trigger/trigger_provider_service.py -services/workflow_draft_variable_service.py -tasks/regenerate_summary_index_task.py -tasks/workflow_cfs_scheduler/cfs_scheduler.py diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index 01642e397e..d9060c1854 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -12,11 +12,11 @@ from libs.email_i18n import EmailType, get_email_i18n_service redis_config = parse_url(dify_config.CELERY_BROKER_URL) celery_redis = Redis( - host=redis_config.get("hostname") or "localhost", - port=redis_config.get("port") or 6379, - password=redis_config.get("password") or None, + host=str(redis_config.get("hostname") or "localhost"), + port=int(redis_config.get("port") or 6379), + password=str(pwd) if (pwd := redis_config.get("password")) is not None else None, db=int(redis_config.get("virtual_host")) if redis_config.get("virtual_host") else 1, - ssl=bool(dify_config.BROKER_USE_SSL), + ssl=dify_config.BROKER_USE_SSL, ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS if dify_config.BROKER_USE_SSL else None, ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None, ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None, diff --git a/api/services/account_service.py b/api/services/account_service.py index 6eb35cb23b..1ab5fbd450 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1414,7 +1414,7 @@ class TenantService: .limit(1) ) if ta: - tenant.role = ta.role + object.__setattr__(tenant, "role", ta.role) else: raise TenantNotFoundError("Tenant not found for the account.") return tenant diff --git a/api/services/app_service.py b/api/services/app_service.py index 22fbf41f84..f288b04f3b 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -551,9 +551,9 @@ class AppService: keys = list(tool.keys()) if len(keys) >= 4: # current tool standard - provider_type = tool.get("provider_type", "") - provider_id = tool.get("provider_id", "") - tool_name = tool.get("tool_name", "") + provider_type = str(tool.get("provider_type", "")) + provider_id = str(tool.get("provider_id", "")) + tool_name = str(tool.get("tool_name", "")) if provider_type == "builtin": meta["tool_icons"][tool_name] = url_prefix + provider_id + "/icon" elif provider_type == "api": diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index ee8a1c4edd..557ae8e89f 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -128,6 +128,8 @@ class ConversationService: if auto_generate: return cls.auto_generate_name(app_model, conversation) else: + if name is None: + raise ValueError("name is required when auto_generate is false") conversation.name = name conversation.updated_at = naive_utc_now() db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 977ee1192c..ddf35e35e6 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2173,7 +2173,7 @@ class DocumentService: ) documents_map = {document.name: document for document in db_documents} for file in files: - data_source_info: dict[str, str | bool] = { + data_source_info: dict[str, object] = { "upload_file_id": file.id, } document = documents_map.get(file.name) @@ -2706,7 +2706,7 @@ class DocumentService: # update document data source if document_data.data_source: file_name = "" - data_source_info: dict[str, str | bool] = {} + data_source_info: dict[str, object] = {} if document_data.data_source.info_list.data_source_type == "upload_file": if not document_data.data_source.info_list.file_info_list: raise ValueError("No file info list found.") diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 910f54bebc..4040b03cc4 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -6,6 +6,7 @@ from core.rag.entities import Rule from core.rag.entities.metadata_entities import MetadataFilteringCondition from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod +from models.enums import ProcessRuleMode class RerankingModel(BaseModel): @@ -55,7 +56,7 @@ class DataSource(BaseModel): class ProcessRule(BaseModel): - mode: Literal["automatic", "custom", "hierarchical"] + mode: ProcessRuleMode rules: Rule | None = None diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index 88b9eeefa1..8239186bbc 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -90,7 +90,6 @@ class PluginMigration: # Use lock when updating counter with counter_lock: - nonlocal handled_tenant_count handled_tenant_count += 1 click.echo( click.style( diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 005bcd40a8..d5d258aeab 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -467,7 +467,7 @@ class TriggerProviderService: if not subscription: raise ValueError(f"Trigger provider subscription {subscription_id} not found") - if subscription.credential_type != CredentialType.OAUTH2.value: + if subscription.credential_type != CredentialType.OAUTH2: raise ValueError("Only OAuth credentials can be refreshed") provider_id = TriggerProviderID(subscription.provider_id) diff --git a/api/services/workflow/scheduler.py b/api/services/workflow/scheduler.py index 7728c7f470..a22e9d33ab 100644 --- a/api/services/workflow/scheduler.py +++ b/api/services/workflow/scheduler.py @@ -13,12 +13,14 @@ class SchedulerCommand(StrEnum): NONE = "none" -class CFSPlanScheduler(ABC): +class CFSPlanScheduler[PlanT: WorkflowScheduleCFSPlanEntity](ABC): """ CFS plan scheduler. """ - def __init__(self, plan: WorkflowScheduleCFSPlanEntity): + plan: PlanT + + def __init__(self, plan: PlanT): """ Initialize the CFS plan scheduler. diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index af25edd2bd..5151b7a08c 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence, Set from concurrent.futures import ThreadPoolExecutor from datetime import datetime from enum import StrEnum -from typing import Any, ClassVar, NotRequired, TypedDict +from typing import Any, ClassVar, NotRequired, TypedDict, cast from sqlalchemy import Engine, delete, orm, select from sqlalchemy.dialects.mysql import insert as mysql_insert @@ -1068,9 +1068,10 @@ class DraftVariableSaver: original_length = len(value_seg.value) # Prepare content for storage + original_content_serialized: str if isinstance(value_seg, StringSegment): # For string types, store as plain text - original_content_serialized = value_seg.value + original_content_serialized = cast(str, value_seg.value) content_type = "text/plain" filename = f"{self._generate_filename(name)}.txt" else: diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index e794195c92..16b59fdbba 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -12,6 +12,7 @@ from core.db.session_factory import session_factory from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument +from models.enums import SummaryStatus from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) @@ -102,8 +103,8 @@ def regenerate_summary_index_task( DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content # Include completed summaries or error summaries (with content) or_( - DocumentSegmentSummary.status == "completed", - DocumentSegmentSummary.status == "error", + DocumentSegmentSummary.status == SummaryStatus.COMPLETED, + DocumentSegmentSummary.status == SummaryStatus.ERROR, ), DatasetDocument.enabled == True, # Document must be enabled DatasetDocument.archived == False, # Document must not be archived @@ -174,7 +175,7 @@ def regenerate_summary_index_task( ) total_segments_failed += 1 # Update summary record with error status - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = f"Re-vectorization failed: {str(e)}" session.add(summary_record) session.commit() @@ -240,10 +241,10 @@ def regenerate_summary_index_task( ) for segment in segments: - summary_record = None + existing_summary_record: DocumentSegmentSummary | None = None try: # Get existing summary record - summary_record = session.scalar( + existing_summary_record = session.scalar( select(DocumentSegmentSummary) .where( DocumentSegmentSummary.chunk_id == segment.id, @@ -252,7 +253,7 @@ def regenerate_summary_index_task( .limit(1) ) - if not summary_record: + if not existing_summary_record: logger.warning("Summary record not found for segment %s, skipping", segment.id) continue @@ -272,10 +273,10 @@ def regenerate_summary_index_task( ) total_segments_failed += 1 # Update summary record with error status - if summary_record: - summary_record.status = "error" - summary_record.error = f"Regeneration failed: {str(e)}" - session.add(summary_record) + if existing_summary_record is not None: + existing_summary_record.status = SummaryStatus.ERROR + existing_summary_record.error = f"Regeneration failed: {str(e)}" + session.add(existing_summary_record) session.commit() continue diff --git a/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py b/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py index 218e61f6d9..972501290b 100644 --- a/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py +++ b/api/tasks/workflow_cfs_scheduler/cfs_scheduler.py @@ -11,13 +11,11 @@ class AsyncWorkflowCFSPlanEntity(WorkflowScheduleCFSPlanEntity): queue: AsyncWorkflowQueue -class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler): +class AsyncWorkflowCFSPlanScheduler(CFSPlanScheduler[AsyncWorkflowCFSPlanEntity]): """ Trigger workflow CFS plan scheduler. """ - plan: AsyncWorkflowCFSPlanEntity - def can_schedule(self) -> SchedulerCommand: """ Check if the workflow can be scheduled. diff --git a/api/tests/unit_tests/services/workflow/test_scheduler.py b/api/tests/unit_tests/services/workflow/test_scheduler.py index 90b6cb2d8b..3b44884161 100644 --- a/api/tests/unit_tests/services/workflow/test_scheduler.py +++ b/api/tests/unit_tests/services/workflow/test_scheduler.py @@ -37,7 +37,7 @@ class TestCFSPlanScheduler: granularity=10, ) with pytest.raises(TypeError): - CFSPlanScheduler(plan) + CFSPlanScheduler(plan) # type: ignore def test_concrete_subclass_can_schedule(self): plan = WorkflowScheduleCFSPlanEntity( diff --git a/packages/contracts/generated/api/console/datasets/types.gen.ts b/packages/contracts/generated/api/console/datasets/types.gen.ts index 4d15231980..a9170dd261 100644 --- a/packages/contracts/generated/api/console/datasets/types.gen.ts +++ b/packages/contracts/generated/api/console/datasets/types.gen.ts @@ -664,7 +664,7 @@ export type DataSource = { } export type ProcessRule = { - mode: 'automatic' | 'custom' | 'hierarchical' + mode: ProcessRuleMode rules?: Rule } @@ -866,6 +866,8 @@ export type InfoList = { website_info_list?: WebsiteInfo } +export type ProcessRuleMode = 'automatic' | 'custom' | 'hierarchical' + export type Rule = { parent_mode?: 'full-doc' | 'paragraph' | null pre_processing_rules?: Array | null diff --git a/packages/contracts/generated/api/console/datasets/zod.gen.ts b/packages/contracts/generated/api/console/datasets/zod.gen.ts index e4e2f37234..44b491d01a 100644 --- a/packages/contracts/generated/api/console/datasets/zod.gen.ts +++ b/packages/contracts/generated/api/console/datasets/zod.gen.ts @@ -709,6 +709,13 @@ export const zDatasetRerankingModel = z.object({ reranking_provider_name: z.string().optional(), }) +/** + * ProcessRuleMode + * + * Dataset Process Rule Mode + */ +export const zProcessRuleMode = z.enum(['automatic', 'custom', 'hierarchical']) + /** * RerankingModel */ @@ -1128,7 +1135,7 @@ export const zRule = z.object({ * ProcessRule */ export const zProcessRule = z.object({ - mode: z.enum(['automatic', 'custom', 'hierarchical']), + mode: zProcessRuleMode, rules: zRule.optional(), }) diff --git a/packages/contracts/generated/api/service/types.gen.ts b/packages/contracts/generated/api/service/types.gen.ts index 4cba0beceb..c90316be79 100644 --- a/packages/contracts/generated/api/service/types.gen.ts +++ b/packages/contracts/generated/api/service/types.gen.ts @@ -721,10 +721,12 @@ export type PreProcessingRule = { } export type ProcessRule = { - mode: 'automatic' | 'custom' | 'hierarchical' + mode: ProcessRuleMode rules?: Rule } +export type ProcessRuleMode = 'automatic' | 'custom' | 'hierarchical' + export type RerankingModel = { reranking_model_name?: string | null reranking_provider_name?: string | null diff --git a/packages/contracts/generated/api/service/zod.gen.ts b/packages/contracts/generated/api/service/zod.gen.ts index 89f4b0e81e..086cf1913b 100644 --- a/packages/contracts/generated/api/service/zod.gen.ts +++ b/packages/contracts/generated/api/service/zod.gen.ts @@ -894,6 +894,13 @@ export const zPreProcessingRule = z.object({ id: z.string(), }) +/** + * ProcessRuleMode + * + * Dataset Process Rule Mode + */ +export const zProcessRuleMode = z.enum(['automatic', 'custom', 'hierarchical']) + /** * RerankingModel */ @@ -1062,7 +1069,7 @@ export const zRule = z.object({ * ProcessRule */ export const zProcessRule = z.object({ - mode: z.enum(['automatic', 'custom', 'hierarchical']), + mode: zProcessRuleMode, rules: zRule.optional(), })