diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
index 6c02646c22..a8077d9eb0 100644
--- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
+++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py
@@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
             raise InvokeRateLimitHttpError(ex.description)
 
 
-# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
-#     @setup_required
-#     @login_required
-#     @account_initialization_required
-#     @get_rag_pipeline
-#     def post(self, pipeline: Pipeline, node_id: str):
-#         """
-#         Run rag pipeline datasource
-#         """
-#         # The role of the current user in the ta table must be admin, owner, or editor
-#         if not current_user.has_edit_permission:
-#             raise Forbidden()
-#
-#         if not isinstance(current_user, Account):
-#             raise Forbidden()
-#
-#         parser = (reqparse.RequestParser()
-#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#             .add_argument("datasource_type", type=str, required=True, location="json")
-#         )
-#         args = parser.parse_args()
-#
-#         job_id = args.get("job_id")
-#         if job_id == None:
-#             raise ValueError("missing job_id")
-#         datasource_type = args.get("datasource_type")
-#         if datasource_type == None:
-#             raise ValueError("missing datasource_type")
-#
-#         rag_pipeline_service = RagPipelineService()
-#         result = rag_pipeline_service.run_datasource_workflow_node_status(
-#             pipeline=pipeline,
-#             node_id=node_id,
-#             job_id=job_id,
-#             account=current_user,
-#             datasource_type=datasource_type,
-#             is_published=True
-#         )
-#
-#         return result
-
-
-# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
-#     @setup_required
-#     @login_required
-#     @account_initialization_required
-#     @get_rag_pipeline
-#     def post(self, pipeline: Pipeline, node_id: str):
-#         """
-#         Run rag pipeline datasource
-#         """
-#         # The role of the current user in the ta table must be admin, owner, or editor
-#         if not current_user.has_edit_permission:
-#             raise Forbidden()
-#
-#         if not isinstance(current_user, Account):
-#             raise Forbidden()
-#
-#         parser = (reqparse.RequestParser()
-#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#             .add_argument("datasource_type", type=str, required=True, location="json")
-#         )
-#         args = parser.parse_args()
-#
-#         job_id = args.get("job_id")
-#         if job_id == None:
-#             raise ValueError("missing job_id")
-#         datasource_type = args.get("datasource_type")
-#         if datasource_type == None:
-#             raise ValueError("missing datasource_type")
-#
-#         rag_pipeline_service = RagPipelineService()
-#         result = rag_pipeline_service.run_datasource_workflow_node_status(
-#             pipeline=pipeline,
-#             node_id=node_id,
-#             job_id=job_id,
-#             account=current_user,
-#             datasource_type=datasource_type,
-#             is_published=False
-#         )
-#
-#         return result
-#
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])
diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py
index 5d79e1b5e9..845af37365 100644
--- a/api/controllers/console/human_input_form.py
+++ b/api/controllers/console/human_input_form.py
@@ -7,7 +7,8 @@ import logging
 from collections.abc import Generator
 
 from flask import Response, jsonify, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
 
@@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
 
 logger = logging.getLogger(__name__)
 
+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 def _jsonify_form_definition(form: Form) -> Response:
     payload = form.get_definition().model_dump()
     payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
             "action": "Approve"
         }
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("action", type=str, required=True, location="json")
-        args = parser.parse_args()
+        payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
 
         current_user, _ = current_account_with_tenant()
         service = HumanInputService(db.engine)
@@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
             service.submit_form_by_token(
                 recipient_type=recipient_type,
                 form_token=form_token,
-                selected_action_id=args["action"],
-                form_data=args["inputs"],
+                selected_action_id=payload.action,
+                form_data=payload.inputs,
                 submission_user_id=current_user.id,
             )
diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py
index 36728a47d1..aff0b42d95 100644
--- a/api/controllers/web/human_input_form.py
+++ b/api/controllers/web/human_input_form.py
@@ -7,7 +7,8 @@ import logging
 from datetime import datetime
 
 from flask import Response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden
 
@@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
 
 logger = logging.getLogger(__name__)
 
+
+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 _FORM_SUBMIT_RATE_LIMITER = RateLimiter(
     prefix="web_form_submit_rate_limit",
     max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
             "action": "Approve"
         }
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("action", type=str, required=True, location="json")
-        args = parser.parse_args()
+        payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
 
         ip_address = extract_remote_ip(request)
         if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
@@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
             service.submit_form_by_token(
                 recipient_type=recipient_type,
                 form_token=form_token,
-                selected_action_id=args["action"],
-                form_data=args["inputs"],
+                selected_action_id=payload.action,
+                form_data=payload.inputs,
                 submission_end_user_id=None,
                 # submission_end_user_id=_end_user.id,
             )
diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py
index b4d2310da8..36daaf09e9 100644
--- a/api/core/app/apps/pipeline/pipeline_runner.py
+++ b/api/core/app/apps/pipeline/pipeline_runner.py
@@ -2,7 +2,6 @@ import logging
 import time
 from typing import cast
 
-from graphon.entities import GraphInitParams
 from graphon.enums import WorkflowType
 from graphon.graph import Graph
 from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
@@ -22,7 +21,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
+from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
 from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
 from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
@@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
         #     graph_config["nodes"] = real_run_nodes
         #     graph_config["edges"] = real_edges
         # init graph
-        # Create required parameters for Graph.init
-        graph_init_params = GraphInitParams(
+        # Create explicit graph init context for Graph.init.
+        run_context = build_dify_run_context(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            user_id=self.application_generate_entity.user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow.id,
             graph_config=graph_config,
-            run_context=build_dify_run_context(
-                tenant_id=workflow.tenant_id,
-                app_id=self._app_id,
-                user_id=self.application_generate_entity.user_id,
-                user_from=user_from,
-                invoke_from=invoke_from,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
         if start_node_id is None:
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index caa6b82bab..437432611d 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -3,7 +3,6 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
-from graphon.entities import GraphInitParams
 from graphon.entities.graph_config import NodeConfigDictAdapter
 from graphon.entities.pause_reason import HumanInputRequired
 from graphon.graph import Graph
@@ -67,7 +66,12 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.rag.entities import RetrievalSourceMetadata
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
+from core.workflow.node_factory import (
+    DifyGraphInitContext,
+    DifyNodeFactory,
+    get_default_root_node_id,
+    resolve_workflow_node_class,
+)
 from core.workflow.system_variables import (
     build_bootstrap_variables,
     default_system_variables,
@@ -127,24 +131,25 @@ class WorkflowBasedAppRunner:
         if not isinstance(graph_config.get("edges"), list):
             raise ValueError("edges in workflow graph must be a list")
 
-        # Create required parameters for Graph.init
-        graph_init_params = GraphInitParams(
+        # Create explicit graph init context for Graph.init.
+        run_context = build_dify_run_context(
+            tenant_id=tenant_id or "",
+            app_id=self._app_id,
+            user_id=user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow_id,
             graph_config=graph_config,
-            run_context=build_dify_run_context(
-                tenant_id=tenant_id or "",
-                app_id=self._app_id,
-                user_id=user_id,
-                user_from=user_from,
-                invoke_from=invoke_from,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
 
         # Use the provided graph_runtime_state for consistent state management
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
 
@@ -289,22 +294,23 @@ class WorkflowBasedAppRunner:
 
         typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
 
-        # Create required parameters for Graph.init
-        graph_init_params = GraphInitParams(
+        # Create explicit graph init context for Graph.init.
+        run_context = build_dify_run_context(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            user_id=user_id,
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.DEBUGGER,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow.id,
             graph_config=graph_config,
-            run_context=build_dify_run_context(
-                tenant_id=workflow.tenant_id,
-                app_id=self._app_id,
-                user_id=user_id,
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.DEBUGGER,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py
index d015769b54..1d8356acf6 100644
--- a/api/core/mcp/auth/auth_flow.py
+++ b/api/core/mcp/auth/auth_flow.py
@@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
                 return ProtectedResourceMetadata.model_validate(response.json())
             elif response.status_code == 404:
                 continue  # Try next URL
-        except (RequestError, ValidationError):
+        except (RequestError, ValidationError, json.JSONDecodeError):
             continue  # Try next URL
 
     return None
@@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
                 return OAuthMetadata.model_validate(response.json())
             elif response.status_code == 404:
                 continue  # Try next URL
-        except (RequestError, ValidationError):
+        except (RequestError, ValidationError, json.JSONDecodeError):
             continue  # Try next URL
 
     return None
@@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
             else:
                 return False, ""
         return False, ""
-    except RequestError:
+    except (RequestError, json.JSONDecodeError, IndexError):
         # Not support resource discovery, fall back to well-known OAuth metadata
         return False, ""
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 09c84538a9..5809d6f74a 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -61,27 +61,28 @@ class TokenBufferMemory:
         :param is_user_message: whether this is a user message
         :return: PromptMessage
         """
-        if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
-            file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
-        elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            app = self.conversation.app
-            if not app:
-                raise ValueError("App not found for conversation")
+        match self.conversation.mode:
+            case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
+                file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+            case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
+                app = self.conversation.app
+                if not app:
+                    raise ValueError("App not found for conversation")
 
-            if not message.workflow_run_id:
-                raise ValueError("Workflow run ID not found")
+                if not message.workflow_run_id:
+                    raise ValueError("Workflow run ID not found")
 
-            workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
-                tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
-            )
-            if not workflow_run:
-                raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
-            workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
-            if not workflow:
-                raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
-            file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
-        else:
-            raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
+                workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
+                    tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
+                )
+                if not workflow_run:
+                    raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
+                workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
+                if not workflow:
+                    raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
+                file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+            case _:
+                raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
 
         detail = ImagePromptMessageContent.DETAIL.HIGH
         if file_extra_config and app_record:
diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
index df02c584ed..90d6d98c63 100644
--- a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
+++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
@@ -5,6 +5,7 @@ from typing import Any
 
 from elasticsearch import Elasticsearch
 from pydantic import BaseModel, model_validator
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -19,6 +20,16 @@ from models.dataset import Dataset
 logger = logging.getLogger(__name__)
 
 
+class HuaweiElasticsearchParamsDict(TypedDict, total=False):
+    hosts: list[str]
+    verify_certs: bool
+    ssl_show_warn: bool
+    request_timeout: int
+    retry_on_timeout: bool
+    max_retries: int
+    basic_auth: tuple[str, str]
+
+
 def create_ssl_context() -> ssl.SSLContext:
     ssl_context = ssl.create_default_context()
     ssl_context.check_hostname = False
@@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
             raise ValueError("config HOSTS is required")
         return values
 
-    def to_elasticsearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts.split(","),
-            "verify_certs": False,
-            "ssl_show_warn": False,
-            "request_timeout": 30000,
-            "retry_on_timeout": True,
-            "max_retries": 10,
-        }
+    def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
+        params = HuaweiElasticsearchParamsDict(
+            hosts=self.hosts.split(","),
+            verify_certs=False,
+            ssl_show_warn=False,
+            request_timeout=30000,
+            retry_on_timeout=True,
+            max_retries=10,
+        )
         if self.username and self.password:
             params["basic_auth"] = (self.username, self.password)
         return params
diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index bfcb620618..fbe0bcad02 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
 from tenacity import retry, stop_after_attempt, wait_exponential
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
 UGC_INDEX_PREFIX = "ugc_index"
 
 
+class LindormOpenSearchParamsDict(TypedDict, total=False):
+    hosts: str | None
+    use_ssl: bool
+    pool_maxsize: int
+    timeout: int
+    http_auth: tuple[str, str]
+
+
 class LindormVectorStoreConfig(BaseModel):
     hosts: str | None
     username: str | None = None
@@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
                 raise ValueError("config PASSWORD is required")
         return values
 
-    def to_opensearch_params(self) -> dict[str, Any]:
-        params: dict[str, Any] = {
-            "hosts": self.hosts,
-            "use_ssl": False,
-            "pool_maxsize": 128,
-            "timeout": 30,
-        }
+    def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
+        params = LindormOpenSearchParamsDict(
+            hosts=self.hosts,
+            use_ssl=False,
+            pool_maxsize=128,
+            timeout=30,
+        )
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params
diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
index 2f77776807..50d18cdc4c 100644
--- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
+++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
@@ -6,6 +6,7 @@ from uuid import uuid4
 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from configs.middleware.vdb.opensearch_config import AuthMethod
@@ -21,6 +22,20 @@ from models.dataset import Dataset
 logger = logging.getLogger(__name__)
 
 
+class _OpenSearchHostDict(TypedDict):
+    host: str
+    port: int
+
+
+class OpenSearchParamsDict(TypedDict, total=False):
+    hosts: list[_OpenSearchHostDict]
+    use_ssl: bool
+    verify_certs: bool
+    connection_class: type
+    pool_maxsize: int
+    http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
+
+
 class OpenSearchConfig(BaseModel):
     host: str
     port: int
@@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
             service=self.aws_service,  # type: ignore[arg-type]
         )
 
-    def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": [{"host": self.host, "port": self.port}],
-            "use_ssl": self.secure,
-            "verify_certs": self.verify_certs,
-            "connection_class": Urllib3HttpConnection,
-            "pool_maxsize": 20,
-        }
+    def to_opensearch_params(self) -> OpenSearchParamsDict:
+        params = OpenSearchParamsDict(
+            hosts=[{"host": self.host, "port": self.port}],
+            use_ssl=self.secure,
+            verify_certs=self.verify_certs,
+            connection_class=Urllib3HttpConnection,
+            pool_maxsize=20,
+        )
 
         if self.auth_method == "basic":
             logger.info("Using basic authentication for OpenSearch Vector DB")
diff --git a/api/core/tools/utils/text_processing_utils.py b/api/core/tools/utils/text_processing_utils.py
index 4bfaa5e49b..1dd0605f28 100644
--- a/api/core/tools/utils/text_processing_utils.py
+++ b/api/core/tools/utils/text_processing_utils.py
@@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str:
     # Match Unicode ranges for punctuation and symbols
     # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
-    pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
+    pattern = re.compile(
+        r"""
+        ^
+        (?:
+            [\u2000-\u2025]                  # General Punctuation: spaces, quotes, dashes
+            | [\u2027-\u206F]                # General Punctuation: ellipsis, underscores, etc.
+            | [\u2E00-\u2E7F]                # Supplemental Punctuation: medieval, ancient marks
+            | [\u3000-\u300F]                # CJK Punctuation: 、。〃「」『』 (excludes 【】)
+            | [\u3012-\u303F]                # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc.
+            | ["#$%&'()*+,./:;<=>?@^_`~]     # ASCII punctuation (excludes []【】)
+        )+
+        """,
+        re.VERBOSE,
+    )
     return re.sub(pattern, "", text)
diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py
index f6c3aee4c1..b04ac7da3d 100644
--- a/api/core/workflow/node_factory.py
+++ b/api/core/workflow/node_factory.py
@@ -1,6 +1,7 @@
 import importlib
 import pkgutil
 from collections.abc import Callable, Iterator, Mapping, MutableMapping
+from dataclasses import dataclass
 from functools import lru_cache
 from typing import TYPE_CHECKING, Any, cast, final, override
 
@@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset(
 )
 
 
+@dataclass(frozen=True, slots=True)
+class DifyGraphInitContext:
+    """Explicit graph-init values owned by the workflow layer.
+
+    Dify is gradually removing direct `GraphInitParams` construction from its
+    production call sites. Keep the translation here until `graphon` exposes an
+    equivalent explicit API.
+    """
+
+    workflow_id: str
+    graph_config: Mapping[str, Any]
+    run_context: Mapping[str, Any]
+    call_depth: int
+
+    def to_graph_init_params(self) -> "GraphInitParams":
+        from graphon.entities import GraphInitParams
+
+        return GraphInitParams(
+            workflow_id=self.workflow_id,
+            graph_config=self.graph_config,
+            run_context=self.run_context,
+            call_depth=self.call_depth,
+        )
+
+
 def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
     package = importlib.import_module(package_name)
     for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
@@ -237,6 +263,19 @@ class DifyNodeFactory(NodeFactory):
     Default implementation of NodeFactory that resolves node classes from the live registry.
""" + @classmethod + def from_graph_init_context( + cls, + *, + graph_init_context: DifyGraphInitContext, + graph_runtime_state: "GraphRuntimeState", + ) -> "DifyNodeFactory": + """Bridge Dify's explicit init context into the current `graphon` API.""" + return cls( + graph_init_params=graph_init_context.to_graph_init_params(), + graph_runtime_state=graph_runtime_state, + ) + def __init__( self, graph_init_params: "GraphInitParams", diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 6a0d633627..8c866aea81 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -29,7 +29,7 @@ class TriggerWebhookNode(Node[WebhookData]): def post_init(self) -> None: from core.workflow.node_runtime import DifyFileReferenceFactory - self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + self._file_reference_factory = DifyFileReferenceFactory(self.run_context) @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index cecc20145a..f0a5fbb400 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -24,7 +24,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.node_factory import ( + DifyGraphInitContext, + DifyNodeFactory, + is_start_node_type, + resolve_workflow_node_class, +) from core.workflow.system_variables import ( default_system_variables, get_node_creation_preload_selectors, @@ -251,17 +256,18 @@ class WorkflowEntry: node_version = str(node_config_data.version) node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) - # init graph init params and runtime state - graph_init_params = GraphInitParams( + # init graph context and runtime state + run_context = build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, graph_config=workflow.graph_dict, - run_context=build_dify_run_context( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -313,8 +319,8 @@ class WorkflowEntry: ) # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) node = node_factory.create_node(node_config) @@ -409,17 +415,18 @@ class WorkflowEntry: variable_pool = VariablePool() add_variables_to_pool(variable_pool, default_system_variables()) - # init graph init params and runtime state - graph_init_params = GraphInitParams( + # init graph context and runtime state + run_context = build_dify_run_context( + tenant_id=tenant_id, + app_id="", + user_id=user_id, + 
user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id="", graph_config=graph_dict, - run_context=build_dify_run_context( - tenant_id=tenant_id, - app_id="", - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -430,8 +437,8 @@ class WorkflowEntry: # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, + node_factory = DifyNodeFactory.from_graph_init_context( + graph_init_context=graph_init_context, graph_runtime_state=graph_runtime_state, ) node = node_factory.create_node(node_config) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 1b3ccd1207..86b0550187 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -5,12 +5,30 @@ from typing import Any import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab +from typing_extensions import TypedDict from configs import dify_config from dify_app import DifyApp -def get_celery_ssl_options() -> dict[str, Any] | None: +class _CelerySentinelKwargsDict(TypedDict): + socket_timeout: float | None + password: str | None + + +class CelerySentinelTransportDict(TypedDict): + master_name: str | None + sentinel_kwargs: _CelerySentinelKwargsDict + + +class CelerySSLOptionsDict(TypedDict): + ssl_cert_reqs: int + ssl_ca_certs: str | None + ssl_certfile: str | None + ssl_keyfile: str | None + + +def get_celery_ssl_options() -> CelerySSLOptionsDict | None: """Get SSL configuration for Celery broker/backend connections.""" # Only apply SSL if we're using Redis as broker/backend if not dify_config.BROKER_USE_SSL: @@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None: ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE) - ssl_options = { - "ssl_cert_reqs": ssl_cert_reqs, - "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS, - "ssl_certfile": dify_config.REDIS_SSL_CERTFILE, - "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE, - } - - return ssl_options + return CelerySSLOptionsDict( + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS, + ssl_certfile=dify_config.REDIS_SSL_CERTFILE, + ssl_keyfile=dify_config.REDIS_SSL_KEYFILE, + ) -def get_celery_broker_transport_options() -> dict[str, Any]: +def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]: """Get broker transport options (e.g. 
Redis Sentinel) for Celery connections.""" if dify_config.CELERY_USE_SENTINEL: - return { - "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME, - "sentinel_kwargs": { - "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, - "password": dify_config.CELERY_SENTINEL_PASSWORD, - }, - } + return CelerySentinelTransportDict( + master_name=dify_config.CELERY_SENTINEL_MASTER_NAME, + sentinel_kwargs=_CelerySentinelKwargsDict( + socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT, + password=dify_config.CELERY_SENTINEL_PASSWORD, + ), + ) return {} diff --git a/api/models/model.py b/api/models/model.py index ece3ff8b87..d2ff8065e2 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -674,28 +674,24 @@ class AppModelConfig(TypeBase): def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] + def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig: + return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled}) + @property def suggested_questions_after_answer_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, - json.loads(self.suggested_questions_after_answer) - if self.suggested_questions_after_answer - else {"enabled": False}, - ) + return self._get_enabled_config(self.suggested_questions_after_answer) @property def speech_to_text_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) + return self._get_enabled_config(self.speech_to_text) @property def text_to_speech_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) + return self._get_enabled_config(self.text_to_speech) @property def retriever_resource_dict(self) -> EnabledConfig: - return cast( - EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} - ) + return self._get_enabled_config(self.retriever_resource, default_enabled=True) @property def annotation_reply_dict(self) -> AnnotationReplyConfig: @@ -722,7 +718,7 @@ class AppModelConfig(TypeBase): @property def more_like_this_dict(self) -> EnabledConfig: - return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) + return self._get_enabled_config(self.more_like_this) @property def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: @@ -902,7 +898,7 @@ class InstalledApp(TypeBase): return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) -class TrialApp(Base): +class TrialApp(TypeBase): __tablename__ = "trial_apps" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="trial_app_pkey"), @@ -911,18 +907,26 @@ class TrialApp(Base): sa.UniqueConstraint("app_id", name="unique_trail_app_id"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - app_id = mapped_column(StringUUID, nullable=False) - tenant_id = mapped_column(StringUUID, nullable=False) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) - trial_limit = mapped_column(sa.Integer, nullable=False, default=3) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = 
mapped_column( + sa.DateTime, + nullable=False, + insert_default=func.current_timestamp(), + server_default=func.current_timestamp(), + init=False, + ) + trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3) @property def app(self) -> App | None: return db.session.scalar(select(App).where(App.id == self.app_id)) -class AccountTrialAppRecord(Base): +class AccountTrialAppRecord(TypeBase): __tablename__ = "account_trial_app_records" __table_args__ = ( sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"), @@ -930,11 +934,19 @@ class AccountTrialAppRecord(Base): sa.Index("account_trial_app_record_app_id_idx", "app_id"), sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), ) - id = mapped_column(StringUUID, default=gen_uuidv4_string) - account_id = mapped_column(StringUUID, nullable=False) - app_id = mapped_column(StringUUID, nullable=False) - count = mapped_column(sa.Integer, nullable=False, default=0) - created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column( + StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False + ) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) + created_at: Mapped[datetime] = mapped_column( + sa.DateTime, + nullable=False, + insert_default=func.current_timestamp(), + server_default=func.current_timestamp(), + init=False, + ) @property def app(self) -> App | None: diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py index f71583c1cd..8b767779ce 100644 --- a/api/models/utils/file_input_compat.py +++ b/api/models/utils/file_input_compat.py @@ -66,12 +66,15 @@ def build_file_from_stored_mapping( record_id = resolve_file_record_id(mapping) transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) - if transfer_method == FileTransferMethod.TOOL_FILE and record_id: - mapping["tool_file_id"] = record_id - elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: - mapping["upload_file_id"] = record_id - elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: - mapping["datasource_file_id"] = record_id + match transfer_method: + case FileTransferMethod.TOOL_FILE if record_id: + mapping["tool_file_id"] = record_id + case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id: + mapping["upload_file_id"] = record_id + case FileTransferMethod.DATASOURCE_FILE if record_id: + mapping["datasource_file_id"] = record_id + case _: + pass if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: remote_url = mapping.get("remote_url") diff --git a/api/pyproject.toml b/api/pyproject.toml index cd38ed33f7..5176964e9b 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -173,7 +173,7 @@ dev = [ "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.59.1", + "pyrefly>=0.60.0", ] ############################################################ diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index c6c8a15109..40e1e5f8ab 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -467,61 +467,67 @@ class AppDslService: ) # Initialize app based on mode - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow_data 
= data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise ValueError("Missing workflow data for workflow/advanced chat app") + match app_mode: + case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW: + workflow_data = data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") - environment_variables_list = workflow_data.get("environment_variables", []) - environment_variables = [ - variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list - ] - conversation_variables_list = workflow_data.get("conversation_variables", []) - conversation_variables = [ - variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list - ] + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_conversation_variable_from_mapping(obj) + for obj in conversation_variables_list + ] - workflow_service = WorkflowService() - current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) - if current_draft_workflow: - unique_hash = current_draft_workflow.unique_hash - else: - unique_hash = None - graph = workflow_data.get("graph", {}) - for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ - decrypted_id - for dataset_id in dataset_ids - if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id)) - ] - workflow_service.sync_draft_workflow( - app_model=app, - graph=workflow_data.get("graph", {}), - features=workflow_data.get("features", {}), - unique_hash=unique_hash, - account=account, - environment_variables=environment_variables, - conversation_variables=conversation_variables, - ) - elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: - # Initialize model config - model_config = data.get("model_config") - if not model_config or not isinstance(model_config, dict): - raise ValueError("Missing model_config for chat/agent-chat/completion app") - # Initialize or update model config - if not app.app_model_config: - app_model_config = AppModelConfig( - app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(cast(AppModelConfigDict, model_config)) - app_model_config.id = str(uuid4()) - app.app_model_config_id = app_model_config.id + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + graph = workflow_data.get("graph", {}) + for node in graph.get("nodes", []): + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + dataset_ids = node["data"].get("dataset_ids", []) + node["data"]["dataset_ids"] = [ + decrypted_id + for dataset_id in dataset_ids + if ( + decrypted_id := self.decrypt_dataset_id( + encrypted_data=dataset_id, tenant_id=app.tenant_id + ) + ) + ] + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + 
features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig( + app_id=app.id, created_by=account.id, updated_by=account.id + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) + app_model_config.id = str(uuid4()) + app.app_model_config_id = app_model_config.id - self._session.add(app_model_config) - app_model_config_was_updated.send(app, app_model_config=app_model_config) - else: - raise ValueError("Invalid app mode") + self._session.add(app_model_config) + app_model_config_was_updated.send(app, app_model_config=app_model_config) + case _: + raise ValueError("Invalid app mode") return app @classmethod diff --git a/api/services/file_service.py b/api/services/file_service.py index 50a326d813..7443ca3271 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -132,8 +132,8 @@ class FileService: return file_size <= file_size_limit def get_file_base64(self, file_id: str) -> str: - upload_file = ( - self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = self._session_maker(expire_on_commit=False).scalar( + select(UploadFile).where(UploadFile.id == file_id).limit(1) ) if not upload_file: raise NotFound("File not found") @@ -178,7 +178,7 @@ class FileService: Return a short text preview extracted from a document file. 
""" with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") @@ -200,7 +200,7 @@ class FileService: if not result: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -220,7 +220,7 @@ class FileService: raise NotFound("File not found or signature is invalid") with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -231,7 +231,7 @@ class FileService: def get_public_image_preview(self, file_id: str): with self._session_maker(expire_on_commit=False) as session: - upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found or signature is invalid") @@ -247,7 +247,7 @@ class FileService: def get_file_content(self, file_id: str) -> str: with self._session_maker(expire_on_commit=False) as session: - upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1)) if not upload_file: raise NotFound("File not found") diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py index a58bede8db..9bb0ab6ae2 100644 --- a/api/services/plugin/plugin_auto_upgrade_service.py +++ b/api/services/plugin/plugin_auto_upgrade_service.py @@ -1,14 +1,13 @@ from sqlalchemy import select -from sqlalchemy.orm import sessionmaker -from extensions.ext_database import db +from core.db.session_factory import session_factory from models.account import TenantPluginAutoUpgradeStrategy class PluginAutoUpgradeService: @staticmethod def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: return session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -24,7 +23,7 @@ class PluginAutoUpgradeService: exclude_plugins: list[str], include_plugins: list[str], ) -> bool: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) @@ -51,7 +50,7 @@ class PluginAutoUpgradeService: @staticmethod def exclude_plugin(tenant_id: str, plugin_id: str) -> bool: - with sessionmaker(bind=db.engine).begin() as session: + with session_factory.create_session() as session: exist_strategy = session.scalar( select(TenantPluginAutoUpgradeStrategy) .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id) diff --git 
a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py index 6fb90d356d..1df5fd13b6 100644 --- a/api/services/recommend_app/database/database_retrieval.py +++ b/api/services/recommend_app/database/database_retrieval.py @@ -1,3 +1,5 @@ +from typing import Any, TypedDict + from sqlalchemy import select from constants.languages import languages @@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase from services.recommend_app.recommend_app_type import RecommendAppType +class RecommendedAppItemDict(TypedDict): + id: str + app: App | None + app_id: str + description: Any + copyright: Any + privacy_policy: Any + custom_disclaimer: str + category: str + position: int + is_listed: bool + + +class RecommendedAppsResultDict(TypedDict): + recommended_apps: list[RecommendedAppItemDict] + categories: list[str] + + +class RecommendedAppDetailDict(TypedDict): + id: str + name: str + icon: Any + icon_background: str | None + mode: str + export_data: str + + class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): """ Retrieval recommended app from database """ - def get_recommended_apps_and_categories(self, language: str): + def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict: result = self.fetch_recommended_apps_from_db(language) return result - def get_recommend_app_detail(self, app_id: str): + def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None: result = self.fetch_recommended_app_detail_from_db(app_id) return result @@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): return RecommendAppType.DATABASE @classmethod - def fetch_recommended_apps_from_db(cls, language: str): + def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict: """ Fetch recommended apps from db. :param language: language @@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): ).all() categories = set() - recommended_apps_result = [] + recommended_apps_result: list[RecommendedAppItemDict] = [] for recommended_app in recommended_apps: app = recommended_app.app if not app or not app.is_public: @@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not site: continue - recommended_app_result = { + recommended_app_result: RecommendedAppItemDict = { "id": recommended_app.id, "app": recommended_app.app, "app_id": recommended_app.app_id, @@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): categories.add(recommended_app.category) - return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)} + return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories)) @classmethod - def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None: + def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None: """ Fetch recommended app detail from db. 
:param app_id: App ID @@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase): if not app_model or not app_model.is_public: return None - return { - "id": app_model.id, - "name": app_model.name, - "icon": app_model.icon, - "icon_background": app_model.icon_background, - "mode": app_model.mode, - "export_data": AppDslService.export_dsl(app_model=app_model), - } + return RecommendedAppDetailDict( + id=app_model.id, + name=app_model.name, + icon=app_model.icon, + icon_background=app_model.icon_background, + mode=app_model.mode, + export_data=AppDslService.export_dsl(app_model=app_model), + ) diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 7b69ccfce7..bb767a6759 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -104,32 +104,32 @@ class WebhookService: """ with Session(db.engine) as session: # Get webhook trigger - webhook_trigger = ( - session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first() + webhook_trigger = session.scalar( + select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1) ) if not webhook_trigger: raise ValueError(f"Webhook not found: {webhook_id}") if is_debug: - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version == Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) else: # Check if the corresponding AppTrigger exists - app_trigger = ( - session.query(AppTrigger) - .filter( + app_trigger = session.scalar( + select(AppTrigger) + .where( AppTrigger.app_id == webhook_trigger.app_id, AppTrigger.node_id == webhook_trigger.node_id, AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK, ) - .first() + .limit(1) ) if not app_trigger: @@ -146,14 +146,14 @@ class WebhookService: raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}") # Get workflow - workflow = ( - session.query(Workflow) - .filter( + workflow = session.scalar( + select(Workflow) + .where( Workflow.app_id == webhook_trigger.app_id, Workflow.version != Workflow.VERSION_DRAFT, ) .order_by(Workflow.created_at.desc()) - .first() + .limit(1) ) if not workflow: raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}") diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index a4e9f6943f..ace7ef68c2 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -5,7 +5,7 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast -from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities import WorkflowNodeExecution from graphon.entities.graph_config import NodeConfigDict from graphon.entities.pause_reason import HumanInputRequired from graphon.enums import ( @@ -48,7 +48,12 @@ from core.workflow.human_input_compat import ( normalize_human_input_node_data_for_graph, parse_human_input_delivery_methods, ) -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_factory import ( + LATEST_VERSION, + DifyGraphInitContext, + get_node_type_classes_mapping, + is_start_node_type, +) from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient from core.workflow.system_variables import 
build_bootstrap_variables, build_system_variables, default_system_variables from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool @@ -1204,18 +1209,20 @@ class WorkflowService: node_config: NodeConfigDict, variable_pool: VariablePool, ) -> HumanInputNode: - graph_init_params = GraphInitParams( + run_context = build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=account.id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_context = DifyGraphInitContext( workflow_id=workflow.id, graph_config=workflow.graph_dict, - run_context=build_dify_run_context( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, - user_id=account.id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, - ), + run_context=run_context, call_depth=0, ) + graph_init_params = graph_init_context.to_graph_init_params() graph_runtime_state = GraphRuntimeState( variable_pool=variable_pool, start_at=time.perf_counter(), @@ -1225,7 +1232,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), + runtime=DifyHumanInputNodeRuntime(run_context), ) return node diff --git a/api/tasks/deal_dataset_index_update_task.py b/api/tasks/deal_dataset_index_update_task.py index fa844a8647..c9b5121a08 100644 --- a/api/tasks/deal_dataset_index_update_task.py +++ b/api/tasks/deal_dataset_index_update_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task # type: ignore +from sqlalchemy import select, update from core.db.session_factory import session_factory from core.rag.index_processor.constant.doc_type import DocType @@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): with session_factory.create_session() as session: try: - dataset = session.query(Dataset).filter_by(id=dataset_id).first() + dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1)) if not dataset: raise Exception("Dataset not found") index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_processor = IndexProcessorFactory(index_type).init_index_processor() if action == "upgrade": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() if dataset_documents: dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() for dataset_document in dataset_documents: try: # add from vector index - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] for segment in segments: @@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): # clean keywords index_processor.clean(dataset, None, 
with_keywords=True, delete_child_chunks=False) index_processor.load(dataset, documents, with_keywords=False) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() elif action == "update": - dataset_documents = ( - session.query(DatasetDocument) - .where( + dataset_documents = session.scalars( + select(DatasetDocument).where( DatasetDocument.dataset_id == dataset_id, DatasetDocument.indexing_status == "completed", DatasetDocument.enabled == True, DatasetDocument.archived == False, ) - .all() - ) + ).all() # add new index if dataset_documents: # update document status dataset_documents_ids = [doc.id for doc in dataset_documents] - session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( - {"indexing_status": "indexing"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id.in_(dataset_documents_ids)) + .values(indexing_status="indexing") ) session.commit() @@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): for dataset_document in dataset_documents: # update from vector index try: - segments = ( - session.query(DocumentSegment) + segments = session.scalars( + select(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True, ) .order_by(DocumentSegment.position.asc()) - .all() - ) + ).all() if segments: documents = [] multimodal_documents = [] @@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str): index_processor.load( dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False ) - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "completed"}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="completed") ) session.commit() except Exception as e: - session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( - {"indexing_status": "error", "error": str(e)}, synchronize_session=False + session.execute( + update(DatasetDocument) + .where(DatasetDocument.id == dataset_document.id) + .values(indexing_status="error", error=str(e)) ) session.commit() else: diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index fe533e62af..1f5fdd2657 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -862,6 +862,15 @@ class TestAuthOrchestration: result = discover_protected_resource_metadata(None, "https://api.example.com") assert result is None + # JSONDecodeError (non-JSON 200 response) + mock_get.side_effect = None + bad_json_response = Mock() + bad_json_response.status_code = 200 + bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0) + 
diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
index fe533e62af..1f5fdd2657 100644
--- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
+++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py
@@ -862,6 +862,15 @@ class TestAuthOrchestration:
         result = discover_protected_resource_metadata(None, "https://api.example.com")
         assert result is None

+        # JSONDecodeError (non-JSON 200 response)
+        mock_get.side_effect = None
+        bad_json_response = Mock()
+        bad_json_response.status_code = 200
+        bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
+        mock_get.return_value = bad_json_response
+        result = discover_protected_resource_metadata(None, "https://api.example.com")
+        assert result is None
+
     @patch("core.helper.ssrf_proxy.get")
     def test_discover_oauth_authorization_server_metadata(self, mock_get):
         # Success
@@ -892,6 +901,14 @@ class TestAuthOrchestration:
         result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
         assert result is None

+        # JSONDecodeError (non-JSON 200 response)
+        bad_json_response = Mock()
+        bad_json_response.status_code = 200
+        bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
+        mock_get.return_value = bad_json_response
+        result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
+        assert result is None
+
     def test_get_effective_scope(self):
         prm = ProtectedResourceMetadata(
             resource="https://api.example.com",
@@ -997,6 +1014,24 @@ class TestAuthOrchestration:
         supported, url = check_support_resource_discovery("https://api")
         assert supported is False

+        # Case 6: JSONDecodeError (non-JSON 200 response)
+        mock_get.side_effect = None
+        bad_json_res = Mock()
+        bad_json_res.status_code = 200
+        bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
+        mock_get.return_value = bad_json_res
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is False
+        assert url == ""
+
+        # Case 7: Empty authorization_servers array (IndexError)
+        empty_res = Mock()
+        empty_res.status_code = 200
+        empty_res.json.return_value = {"authorization_servers": []}
+        mock_get.return_value = empty_res
+        supported, url = check_support_resource_discovery("https://api")
+        assert supported is False
+
     def test_discover_oauth_metadata(self):
         with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
             with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:
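The new cases assert that the discovery helpers treat an HTTP 200 response with a non-JSON body as "no metadata" instead of letting json.JSONDecodeError escape. The production change is not part of this diff; a plausible, standalone sketch of the guard these tests exercise (function name and body hypothetical) is:

    import json


    def parse_metadata_response(response):
        """Hypothetical guard matching the tests: a non-200 response, or a 200
        response whose body is not valid JSON, yields None rather than raising."""
        if response.status_code != 200:
            return None
        try:
            return response.json()
        except json.JSONDecodeError:
            return None

The tests drive exactly this path by giving the mock response a `json.side_effect` of `json.JSONDecodeError("Expecting value", "", 0)`.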
diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py
index bc0b339fec..dfe1a47e37 100644
--- a/api/tests/unit_tests/core/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/workflow/test_node_factory.py
@@ -110,6 +110,34 @@ class TestFetchMemory:
         )

+
+class TestDifyGraphInitContext:
+    def test_to_graph_init_params_preserves_explicit_values(self):
+        run_context = {
+            DIFY_RUN_CONTEXT_KEY: DifyRunContext(
+                tenant_id="tenant-id",
+                app_id="app-id",
+                user_id="user-id",
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.DEBUGGER,
+            ),
+            "extra": "value",
+        }
+        graph_config = {"nodes": [], "edges": []}
+        graph_init_context = node_factory.DifyGraphInitContext(
+            workflow_id="workflow-id",
+            graph_config=graph_config,
+            run_context=run_context,
+            call_depth=2,
+        )
+
+        result = graph_init_context.to_graph_init_params()
+
+        assert result.workflow_id == "workflow-id"
+        assert result.graph_config == graph_config
+        assert result.run_context == run_context
+        assert result.call_depth == 2
+
+
 class TestDefaultWorkflowCodeExecutor:
     def test_execute_delegates_to_code_executor(self, monkeypatch):
         executor = node_factory.DefaultWorkflowCodeExecutor()
@@ -172,6 +200,23 @@ class TestCodeExecutorJinja2TemplateRenderer:

 class TestDifyNodeFactoryInit:
+    def test_from_graph_init_context_translates_before_init(self):
+        graph_init_context = MagicMock()
+        graph_init_context.to_graph_init_params.return_value = sentinel.graph_init_params
+
+        with patch.object(node_factory.DifyNodeFactory, "__init__", return_value=None) as init:
+            factory = node_factory.DifyNodeFactory.from_graph_init_context(
+                graph_init_context=graph_init_context,
+                graph_runtime_state=sentinel.graph_runtime_state,
+            )
+
+        assert isinstance(factory, node_factory.DifyNodeFactory)
+        graph_init_context.to_graph_init_params.assert_called_once_with()
+        init.assert_called_once_with(
+            graph_init_params=sentinel.graph_init_params,
+            graph_runtime_state=sentinel.graph_runtime_state,
+        )
+
     def test_init_builds_default_dependencies(self):
         graph_init_params = SimpleNamespace(run_context={"context": "value"})
         graph_runtime_state = sentinel.graph_runtime_state
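The new factory test pins down a translate-then-init contract: from_graph_init_context() must call to_graph_init_params() exactly once and forward the result to __init__. A hedged sketch of that contract with illustrative stand-in classes (the real DifyGraphInitContext and DifyNodeFactory live in the node_factory module and carry more state):

    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class GraphInitParams:  # illustrative stand-in
        workflow_id: str
        graph_config: dict[str, Any]
        run_context: dict[str, Any]
        call_depth: int


    @dataclass
    class DifyGraphInitContext:  # illustrative stand-in
        workflow_id: str
        graph_config: dict[str, Any]
        run_context: dict[str, Any]
        call_depth: int

        def to_graph_init_params(self) -> GraphInitParams:
            # Field-for-field translation, as the unit test asserts.
            return GraphInitParams(
                workflow_id=self.workflow_id,
                graph_config=self.graph_config,
                run_context=self.run_context,
                call_depth=self.call_depth,
            )


    class DifyNodeFactory:
        def __init__(self, *, graph_init_params, graph_runtime_state):
            self.graph_init_params = graph_init_params
            self.graph_runtime_state = graph_runtime_state

        @classmethod
        def from_graph_init_context(cls, *, graph_init_context, graph_runtime_state):
            # Translate the context once, then delegate to the ordinary constructor.
            return cls(
                graph_init_params=graph_init_context.to_graph_init_params(),
                graph_runtime_state=graph_runtime_state,
            )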
diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py
index 879c0bb721..6dcaed1143 100644
--- a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py
+++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py
@@ -349,7 +349,7 @@ class TestWorkflowEntrySingleStepRun:
         ]

         with (
-            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
+            patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
             patch.object(
                 workflow_entry,
                 "GraphRuntimeState",
@@ -358,7 +358,7 @@ class TestWorkflowEntrySingleStepRun:
             patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
             patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
             patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode),
-            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
+            patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
             patch.object(workflow_entry, "load_into_variable_pool"),
             patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
             patch.object(
@@ -412,12 +412,12 @@ class TestWorkflowEntrySingleStepRun:
             raise NotImplementedError

         with (
-            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
+            patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
             patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
             patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
             patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
             patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
-            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
+            patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
             patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
             patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
             patch.object(
@@ -481,12 +481,12 @@ class TestWorkflowEntrySingleStepRun:
             return {"question": ["node", "question"]}

         with (
-            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
+            patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
             patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
             patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
             patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
             patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode),
-            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
+            patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
             patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
             patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
             patch.object(
@@ -541,12 +541,12 @@ class TestWorkflowEntrySingleStepRun:
             return "1"

         with (
-            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
+            patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
             patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
             patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
             patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
             patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
-            patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
+            patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
             patch.object(workflow_entry, "add_node_inputs_to_pool"),
             patch.object(workflow_entry, "load_into_variable_pool"),
             patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
@@ -651,14 +651,18 @@ class TestWorkflowEntryHelpers:
             patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
             patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool,
             patch.object(
-                workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
-            ) as graph_init_params,
+                workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context
+            ) as graph_init_context_cls,
             patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
             patch.object(
                 workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
             ) as build_dify_run_context,
             patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
-            patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
+            patch.object(
+                workflow_entry.DifyNodeFactory,
+                "from_graph_init_context",
+                return_value=dify_node_factory,
+            ) as dify_node_factory_cls,
             patch.object(
                 workflow_entry.WorkflowEntry,
                 "mapping_user_inputs_to_variable_pool",
@@ -688,7 +692,7 @@ class TestWorkflowEntryHelpers:
             user_from=UserFrom.ACCOUNT,
             invoke_from=InvokeFrom.DEBUGGER,
         )
-        graph_init_params.assert_called_once_with(
+        graph_init_context_cls.assert_called_once_with(
             workflow_id="",
             graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
                 "node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}
@@ -697,7 +701,7 @@ class TestWorkflowEntryHelpers:
             call_depth=0,
         )
         dify_node_factory_cls.assert_called_once_with(
-            graph_init_params=sentinel.graph_init_params,
+            graph_init_context=sentinel.graph_init_context,
             graph_runtime_state=sentinel.graph_runtime_state,
         )
         mapping_user_inputs_to_variable_pool.assert_called_once_with(
@@ -734,11 +738,15 @@ class TestWorkflowEntryHelpers:
             patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables),
             patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
             patch.object(workflow_entry, "add_variables_to_pool"),
-            patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
+            patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
             patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
             patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
             patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
-            patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
+            patch.object(
+                workflow_entry.DifyNodeFactory,
+                "from_graph_init_context",
+                return_value=dify_node_factory,
+            ),
             patch.object(
                 workflow_entry.WorkflowEntry,
                 "mapping_user_inputs_to_variable_pool",
diff --git a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py
index bc2f1c6ecc..021bebceff 100644
--- a/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py
+++ b/api/tests/unit_tests/services/plugin/test_plugin_auto_upgrade_service.py
@@ -6,23 +6,23 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"

 def _patched_session():
-    """Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
+    """Patch session_factory.create_session() to return a mock session as context manager."""
     session = MagicMock()
-    mock_sessionmaker = MagicMock()
-    mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
-    mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
-    patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
-    db_patcher = patch(f"{MODULE}.db")
-    return patcher, db_patcher, session
+    session.__enter__ = MagicMock(return_value=session)
+    session.__exit__ = MagicMock(return_value=False)
+    mock_factory = MagicMock()
+    mock_factory.create_session.return_value = session
+    patcher = patch(f"{MODULE}.session_factory", mock_factory)
+    return patcher, session


 class TestGetStrategy:
     def test_returns_strategy_when_found(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         strategy = MagicMock()
         session.scalar.return_value = strategy

-        with p1, p2:
+        with p1:
             from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

             result = PluginAutoUpgradeService.get_strategy("t1")
@@ -30,10 +30,10 @@ class TestGetStrategy:
         assert result is strategy

     def test_returns_none_when_not_found(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         session.scalar.return_value = None

-        with p1, p2:
+        with p1:
             from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

             result = PluginAutoUpgradeService.get_strategy("t1")
@@ -43,10 +43,10 @@ class TestChangeStrategy:
     def test_creates_new_strategy(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         session.scalar.return_value = None

-        with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
+        with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
             strat_cls.return_value = MagicMock()
             from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@@ -63,11 +63,11 @@ class TestChangeStrategy:
             session.add.assert_called_once()

     def test_updates_existing_strategy(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         existing = MagicMock()
         session.scalar.return_value = existing

-        with p1, p2:
+        with p1:
             from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService

             result = PluginAutoUpgradeService.change_strategy(
@@ -89,12 +89,11 @@ class TestChangeStrategy:

 class TestExcludePlugin:
     def test_creates_default_strategy_when_none_exists(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         session.scalar.return_value = None

         with (
             p1,
-            p2,
             patch(f"{MODULE}.select"),
             patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
             patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
@@ -110,13 +109,13 @@ class TestExcludePlugin:
             cs.assert_called_once()

     def test_appends_to_exclude_list_in_exclude_mode(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         existing = MagicMock()
         existing.upgrade_mode = "exclude"
         existing.exclude_plugins = ["p-existing"]
         session.scalar.return_value = existing

-        with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
+        with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
             strat_cls.UpgradeMode.EXCLUDE = "exclude"
             strat_cls.UpgradeMode.PARTIAL = "partial"
             strat_cls.UpgradeMode.ALL = "all"
@@ -128,13 +127,13 @@ class TestExcludePlugin:
         assert existing.exclude_plugins == ["p-existing", "p-new"]

     def test_removes_from_include_list_in_partial_mode(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         existing = MagicMock()
         existing.upgrade_mode = "partial"
         existing.include_plugins = ["p1", "p2"]
         session.scalar.return_value = existing

-        with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
+        with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
             strat_cls.UpgradeMode.EXCLUDE = "exclude"
             strat_cls.UpgradeMode.PARTIAL = "partial"
             strat_cls.UpgradeMode.ALL = "all"
@@ -146,12 +145,12 @@ class TestExcludePlugin:
         assert existing.include_plugins == ["p2"]

     def test_switches_to_exclude_mode_from_all(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         existing = MagicMock()
         existing.upgrade_mode = "all"
         session.scalar.return_value = existing

-        with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
+        with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
             strat_cls.UpgradeMode.EXCLUDE = "exclude"
             strat_cls.UpgradeMode.PARTIAL = "partial"
             strat_cls.UpgradeMode.ALL = "all"
@@ -164,13 +163,13 @@ class TestExcludePlugin:
         assert existing.exclude_plugins == ["p1"]

     def test_no_duplicate_in_exclude_list(self):
-        p1, p2, session = _patched_session()
+        p1, session = _patched_session()
         existing = MagicMock()
         existing.upgrade_mode = "exclude"
         existing.exclude_plugins = ["p1"]
         session.scalar.return_value = existing

-        with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
+        with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
             strat_cls.UpgradeMode.EXCLUDE = "exclude"
             strat_cls.UpgradeMode.PARTIAL = "partial"
             strat_cls.UpgradeMode.ALL = "all"
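The reworked _patched_session helper tracks the service's move from sessionmaker(bind=db.engine).begin() to session_factory.create_session(); the mock session now acts as its own context manager. A self-contained sketch of that mocking shape (the service function here is an illustrative stand-in, not the real PluginAutoUpgradeService):

    from unittest.mock import MagicMock


    def get_strategy(session_factory):
        # Stand-in for the production pattern under test:
        #     with session_factory.create_session() as session:
        #         return session.scalar(select(...))
        with session_factory.create_session() as session:
            return session.scalar("SELECT ...")  # placeholder statement


    def test_get_strategy_with_mocked_factory():
        session = MagicMock()
        # The mock session doubles as its own context manager, as in _patched_session().
        session.__enter__ = MagicMock(return_value=session)
        session.__exit__ = MagicMock(return_value=False)
        session.scalar.return_value = "strategy"

        factory = MagicMock()
        factory.create_session.return_value = session

        assert get_strategy(factory) == "strategy"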
diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py
index b7259c3e82..8e1b22886b 100644
--- a/api/tests/unit_tests/services/test_file_service.py
+++ b/api/tests/unit_tests/services/test_file_service.py
@@ -165,7 +165,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.key = "test_key"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with patch("services.file_service.storage") as mock_storage:
             mock_storage.load_once.return_value = b"test content"
@@ -178,7 +178,7 @@ class TestFileService:
             mock_storage.load_once.assert_called_once_with("test_key")

     def test_get_file_base64_not_found(self, file_service, mock_db_session):
-        mock_db_session.query().where().first.return_value = None
+        mock_db_session.scalar.return_value = None

         with pytest.raises(NotFound, match="File not found"):
             file_service.get_file_base64("non_existent")
@@ -215,7 +215,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.extension = "pdf"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
             mock_extract.return_value = "Extracted text content"
@@ -227,7 +227,7 @@ class TestFileService:
             assert result == "Extracted text content"

     def test_get_file_preview_not_found(self, file_service, mock_db_session):
-        mock_db_session.query().where().first.return_value = None
+        mock_db_session.scalar.return_value = None

         with pytest.raises(NotFound, match="File not found"):
             file_service.get_file_preview("non_existent")
@@ -235,7 +235,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.extension = "exe"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with pytest.raises(UnsupportedFileTypeError):
             file_service.get_file_preview("file_id")
@@ -246,7 +246,7 @@ class TestFileService:
         upload_file.extension = "jpg"
         upload_file.mime_type = "image/jpeg"
         upload_file.key = "key"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with (
             patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
@@ -269,7 +269,7 @@ class TestFileService:
             file_service.get_image_preview("file_id", "ts", "nonce", "sign")

     def test_get_image_preview_not_found(self, file_service, mock_db_session):
-        mock_db_session.query().where().first.return_value = None
+        mock_db_session.scalar.return_value = None
        with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
             mock_verify.return_value = True
             with pytest.raises(NotFound, match="File not found or signature is invalid"):
@@ -279,7 +279,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.extension = "txt"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file
         with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
             mock_verify.return_value = True
             with pytest.raises(UnsupportedFileTypeError):
@@ -289,7 +289,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.key = "key"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with (
             patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
@@ -309,7 +309,7 @@ class TestFileService:
             file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")

     def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
-        mock_db_session.query().where().first.return_value = None
+        mock_db_session.scalar.return_value = None
         with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
             mock_verify.return_value = True
             with pytest.raises(NotFound, match="File not found or signature is invalid"):
@@ -321,7 +321,7 @@ class TestFileService:
         upload_file.extension = "png"
         upload_file.mime_type = "image/png"
         upload_file.key = "key"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with patch("services.file_service.storage") as mock_storage:
             mock_storage.load.return_value = b"image content"
@@ -330,7 +330,7 @@ class TestFileService:
             assert mime == "image/png"

     def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
-        mock_db_session.query().where().first.return_value = None
+        mock_db_session.scalar.return_value = None

         with pytest.raises(NotFound, match="File not found or signature is invalid"):
             file_service.get_public_image_preview("file_id")
@@ -338,7 +338,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.extension = "txt"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with pytest.raises(UnsupportedFileTypeError):
             file_service.get_public_image_preview("file_id")
@@ -346,7 +346,7 @@ class TestFileService:
         upload_file = MagicMock(spec=UploadFile)
         upload_file.id = "file_id"
         upload_file.key = "key"
-        mock_db_session.query().where().first.return_value = upload_file
+        mock_db_session.scalar.return_value = upload_file

         with patch("services.file_service.storage") as mock_storage:
             mock_storage.load.return_value = b"hello world"
@@ -354,7 +354,7 @@ class TestFileService:
             assert result == "hello world"

     def test_get_file_content_not_found(self, file_service, mock_db_session):
-        mock_db_session.query().where().first.return_value = None
+        mock_db_session.scalar.return_value = None

         with pytest.raises(NotFound, match="File not found"):
             file_service.get_file_content("file_id")
diff --git a/api/tests/unit_tests/services/test_webhook_service.py b/api/tests/unit_tests/services/test_webhook_service.py
index 1b5252fc64..39693e3f4b 100644
--- a/api/tests/unit_tests/services/test_webhook_service.py
+++ b/api/tests/unit_tests/services/test_webhook_service.py
@@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App:
 def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
     # Arrange
     fake_session = MagicMock()
-    fake_session.query.return_value = _FakeQuery(None)
+    fake_session.scalar.return_value = None
     _patch_session(monkeypatch, fake_session)

     # Act / Assert
@@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun
     # Arrange
     webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
     fake_session = MagicMock()
-    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
+    fake_session.scalar.side_effect = [webhook_trigger, None]
     _patch_session(monkeypatch, fake_session)

     # Act / Assert
@@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim
     webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
     app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
     fake_session = MagicMock()
-    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
+    fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
     _patch_session(monkeypatch, fake_session)

     # Act / Assert
@@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled
     webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
     app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
     fake_session = MagicMock()
-    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
+    fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
     _patch_session(monkeypatch, fake_session)

     # Act / Assert
@@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m
     webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
     app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
     fake_session = MagicMock()
-    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
+    fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
     _patch_session(monkeypatch, fake_session)

     # Act / Assert
@@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod
     workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}

     fake_session = MagicMock()
-    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
+    fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
     _patch_session(monkeypatch, fake_session)

     # Act
@@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo
     workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}

     fake_session = MagicMock()
-    fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
+    fake_session.scalar.side_effect = [webhook_trigger, workflow]
     _patch_session(monkeypatch, fake_session)

     # Act
diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py
index aaa090a9aa..002ec8dee7 100644
--- a/api/tests/unit_tests/services/test_workflow_service.py
+++ b/api/tests/unit_tests/services/test_workflow_service.py
@@ -2826,9 +2826,9 @@ class TestWorkflowServiceFreeNodeExecution:
         variable_pool = MagicMock()

         with (
-            patch("services.workflow_service.GraphInitParams") as mock_graph_init_params,
+            patch("services.workflow_service.DifyGraphInitContext") as mock_graph_init_context_cls,
             patch("services.workflow_service.GraphRuntimeState"),
-            patch("services.workflow_service.build_dify_run_context"),
+            patch("services.workflow_service.build_dify_run_context") as mock_build_dify_run_context,
             patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
             patch("services.workflow_service.HumanInputNode") as mock_node_cls,
         ):
@@ -2837,4 +2837,17 @@ class TestWorkflowServiceFreeNodeExecution:
             )
             assert node == mock_node_cls.return_value
             mock_node_cls.assert_called_once()
-            mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)
+            mock_graph_init_context_cls.assert_called_once_with(
+                workflow_id="wf-1",
+                graph_config=workflow.graph_dict,
+                run_context=mock_build_dify_run_context.return_value,
+                call_depth=0,
+            )
+            mock_runtime_cls.assert_called_once_with(mock_build_dify_run_context.return_value)
+            mock_node_cls.assert_called_once_with(
+                id="n-1",
+                config=node_config,
+                graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
+                graph_runtime_state=ANY,
+                runtime=mock_runtime_cls.return_value,
+            )
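A note on the side_effect lists used throughout these tests: each call to the mocked session.scalar(...) consumes the next element, so the list must be ordered exactly like the service's lookups (webhook trigger, then app trigger, then workflow). A tiny self-contained illustration of that mock behavior:

    from unittest.mock import MagicMock

    session = MagicMock()
    # One element per expected session.scalar(...) call, in call order.
    session.scalar.side_effect = ["webhook_trigger", "app_trigger", None]

    assert session.scalar("q1") == "webhook_trigger"  # first lookup
    assert session.scalar("q2") == "app_trigger"      # second lookup
    assert session.scalar("q3") is None               # third lookup: not found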
diff --git a/api/tests/unit_tests/utils/test_text_processing.py b/api/tests/unit_tests/utils/test_text_processing.py
index bf61162a66..5f6ccbcdff 100644
--- a/api/tests/unit_tests/utils/test_text_processing.py
+++ b/api/tests/unit_tests/utils/test_text_processing.py
@@ -19,7 +19,57 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
         ("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
         ("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
         # Leading symbols before markdown link are removed, including the opening bracket [
-        ("@[Test](https://example.com)", "Test](https://example.com)"),
+        ("@[Test](https://example.com)", "[Test](https://example.com)"),
+        ("~~标题~~", "标题~~"),
+        ('""quoted', "quoted"),
+        ("''test", "test"),
+        ("##话题", "话题"),
+        ("$$价格", "价格"),
+        ("%%百分比", "百分比"),
+        ("&&与逻辑", "与逻辑"),
+        ("((括号))", "括号))"),
+        ("**强调**", "强调**"),
+        ("++自增", "自增"),
+        (",,逗号", "逗号"),
+        ("..省略", "省略"),
+        ("//注释", "注释"),
+        ("::范围", "范围"),
+        (";;分号", "分号"),
+        ("<<左移", "左移"),
+        ("==等于", "等于"),
+        (">>右移", "右移"),
+        ("??疑问", "疑问"),
+        ("@@提及", "提及"),
+        ("^^上标", "上标"),
+        ("__下划线", "下划线"),
+        ("``代码", "代码"),
+        ("~~删除线", "删除线"),
+        ("　全角空格开头", "全角空格开头"),
+        ("、顿号开头", "顿号开头"),
+        ("。句号开头", "句号开头"),
+        ("「引号」测试", "引号」测试"),
+        ("『书名号』", "书名号』"),
+        ("【保留】测试", "【保留】测试"),
+        ("〖括号〗测试", "括号〗测试"),
+        ("〔括号〕测试", "括号〕测试"),
+        ("~~【保留】~~", "【保留】~~"),
+        ('"[公告]"', '[公告]"'),
+        ("[公告] 更新", "[公告] 更新"),
+        ("【通知】重要", "【通知】重要"),
+        ("[[嵌套]]", "[[嵌套]]"),
+        ("【【嵌套】】", "【【嵌套】】"),
+        ("[【混合】]", "[【混合】]"),
+        ("normal text", "normal text"),
+        ("123数字", "123数字"),
+        ("中文开头", "中文开头"),
+        ("alpha", "alpha"),
+        ("~", ""),
+        ("【", "【"),
+        ("[", "["),
+        ("~~~", ""),
+        ("【【【", "【【【"),
+        ("\t制表符", "\t制表符"),
+        ("\n换行", "\n换行"),
     ],
 )
 def test_remove_leading_symbols(input_text, expected_output):
diff --git a/api/uv.lock b/api/uv.lock
index b9c2693272..a516b57107 100644
--- a/api/uv.lock
+++ b/api/uv.lock
@@ -1585,7 +1585,7 @@ dev = [
     { name = "lxml-stubs", specifier = "~=0.5.1" },
     { name = "mypy", specifier = "~=1.20.0" },
     { name = "pandas-stubs", specifier = "~=3.0.0" },
-    { name = "pyrefly", specifier = ">=0.59.1" },
+    { name = "pyrefly", specifier = ">=0.60.0" },
     { name = "pytest", specifier = "~=9.0.2" },
     { name = "pytest-benchmark", specifier = "~=5.2.3" },
     { name = "pytest-cov", specifier = "~=7.1.0" },
@@ -4850,19 +4850,19 @@ wheels = [

 [[package]]
 name = "pyrefly"
-version = "0.59.1"
+version = "0.60.0"
 source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c6/c7/28d14b64888e2d03815627ebff8d57a9f08389c4bbebfe70ae1ed98a1267/pyrefly-0.60.0.tar.gz", hash = "sha256:2499f5b6ff5342e86dfe1cd94bcce133519bbbc93b7ad5636195fea4f0fa3b81", size = 5500389, upload-time = "2026-04-06T19:57:30.643Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
-    { url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
-    { url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
-    { url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
-    { url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
-    { url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
-    { url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
-    { url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
-    { url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
+    { url = "https://files.pythonhosted.org/packages/31/99/6c9984a09220e5eb7dd5c869b7a32d25c3d06b5e8854c6eb679db1145c3e/pyrefly-0.60.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bf1691af0fee69d0c99c3c6e9d26ab6acd3c8afef96416f9ba2e74934833b7b5", size = 12921262, upload-time = "2026-04-06T19:57:00.745Z" },
+    { url = "https://files.pythonhosted.org/packages/05/b3/6216aa3c00c88e59a27eb4149851b5affe86eeea6129f4224034a32dddb0/pyrefly-0.60.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3e71b70c9b95545cf3b479bc55d1381b531de7b2380eb64411088a1e56b634cb", size = 12424413, upload-time = "2026-04-06T19:57:03.417Z" },
+    { url = "https://files.pythonhosted.org/packages/9b/87/eb8dd73abd92a93952ac27a605e463c432fb250fb23186574038c7035594/pyrefly-0.60.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:680ee5f8f98230ea145652d7344708f5375786209c5bf03d8b911fdb0d0d4195", size = 35940884, upload-time = "2026-04-06T19:57:06.909Z" },
+    { url = "https://files.pythonhosted.org/packages/0d/34/dc6aeb67b840c745fcee6db358295d554abe6ab555a7eaaf44624bd80bf1/pyrefly-0.60.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d0b20dbbe4aff15b959e8d825b7521a144c4122c11e57022e83b36568c54470", size = 38677220, upload-time = "2026-04-06T19:57:11.235Z" },
+    { url = "https://files.pythonhosted.org/packages/66/6b/c863fcf7ef592b7d1db91502acf0d1113be8bed7a2a7143fc6f0dd90616f/pyrefly-0.60.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2911563c8e6b2eaefff68885c94727965469a35375a409235a7a4d2b7157dc15", size = 36907431, upload-time = "2026-04-06T19:57:15.074Z" },
+    { url = "https://files.pythonhosted.org/packages/8e/a2/25ea095ab2ecca8e62884669b11a79f14299db93071685b73a97efbaf4f3/pyrefly-0.60.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a631d9d04705e303fe156f2e62551611bc7ef8066c34708ceebcfb3088bd55", size = 41447898, upload-time = "2026-04-06T19:57:19.382Z" },
+    { url = "https://files.pythonhosted.org/packages/8e/2c/097bdc6e8d40676b28eb03710a4577bc3c7b803cd24693ac02bf15de3d67/pyrefly-0.60.0-py3-none-win32.whl", hash = "sha256:a08d69298da5626cf502d3debbb6944fd13d2f405ea6625363751f1ff570d366", size = 11913434, upload-time = "2026-04-06T19:57:22.887Z" },
+    { url = "https://files.pythonhosted.org/packages/0a/d4/8d27fe310e830c8d11ab73db38b93f9fd2e218744b6efb1204401c9a74d5/pyrefly-0.60.0-py3-none-win_amd64.whl", hash = "sha256:56cf30654e708ae1dd635ffefcba4fa4b349dd7004a6ccc5c41e3a9bb944320c", size = 12745033, upload-time = "2026-04-06T19:57:25.517Z" },
+    { url = "https://files.pythonhosted.org/packages/1f/ad/8eea1f8fb8209f91f6dbfe48000c9d05fd0cdb1b5b3157283c9b1dada55d/pyrefly-0.60.0-py3-none-win_arm64.whl", hash = "sha256:b6d27fba970f4777063c0227c54167d83bece1804ea34f69e7118e409ba038d2", size = 12246390, upload-time = "2026-04-06T19:57:28.141Z" },
 ]

 [[package]]
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 681e3ad392..d2105db95c 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -520,8 +520,8 @@ catalogs:
       specifier: 13.0.0
       version: 13.0.0
     vinext:
-      specifier: https://pkg.pr.new/vinext@adbf24d
-      version: 0.0.5
+      specifier: 0.0.41
+      version: 0.0.41
     vite-plugin-inspect:
       specifier: 12.0.0-beta.1
       version: 12.0.0-beta.1
@@ -1162,7 +1162,7 @@ importers:
         version: 3.19.3
       vinext:
         specifier: 'catalog:'
-        version: https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
+        version: 0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
       vite:
         specifier: npm:@voidzero-dev/vite-plus-core@0.1.16
         version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)'
@@ -8336,9 +8336,8 @@ packages:

   vfile@6.0.3:
     resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}

-  vinext@https://pkg.pr.new/vinext@adbf24d:
-    resolution: {tarball: https://pkg.pr.new/vinext@adbf24d}
-    version: 0.0.5
+  vinext@0.0.41:
+    resolution: {integrity: sha512-fpQjNp6cIqjYGH2/kbhN2SdIYHEu79RdlII23SWsY1Qp7LM+je8GfTJH1sxw6dASxPhZKZB/jCmTm5d2/D25zw==}
     engines: {node: '>=22'}
     hasBin: true
     peerDependencies:
@@ -16586,7 +16585,7 @@ snapshots:
       '@types/unist': 3.0.3
       vfile-message: 4.0.3

-  vinext@https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
+  vinext@0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
     dependencies:
       '@unpic/react': 1.0.2(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
       '@vercel/og': 0.8.6
diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml
index f81a452011..91689f4d0c 100644
--- a/pnpm-workspace.yaml
+++ b/pnpm-workspace.yaml
@@ -221,7 +221,7 @@ catalog:
   unist-util-visit: 5.1.0
   use-context-selector: 2.0.0
   uuid: 13.0.0
-  vinext: https://pkg.pr.new/vinext@adbf24d
+  vinext: 0.0.41
   vite: npm:@voidzero-dev/vite-plus-core@0.1.16
   vite-plugin-inspect: 12.0.0-beta.1
   vite-plus: 0.1.16
diff --git a/web/app/components/workflow/nodes/tool/__tests__/output-schema-utils.spec.ts b/web/app/components/workflow/nodes/tool/__tests__/output-schema-utils.spec.ts
index 4d095ab189..f5179742b2 100644
--- a/web/app/components/workflow/nodes/tool/__tests__/output-schema-utils.spec.ts
+++ b/web/app/components/workflow/nodes/tool/__tests__/output-schema-utils.spec.ts
@@ -229,6 +229,23 @@ describe('output-schema-utils', () => {
     })
   })

+  describe('Dify compact types (workflow-as-tool output_schema)', () => {
+    it('should resolve array[string] to arrayString (issue #34428)', () => {
+      const result = resolveVarType({ type: 'array[string]' })
+      expect(result.type).toBe(VarType.arrayString)
+    })
+
+    it('should resolve Array[string] case-insensitively', () => {
+      const result = resolveVarType({ type: 'Array[string]' })
+      expect(result.type).toBe(VarType.arrayString)
+    })
+
+    it('should resolve array[object] to arrayObject', () => {
+      const result = resolveVarType({ type: 'array[object]' })
+      expect(result.type).toBe(VarType.arrayObject)
+    })
+  })
+
   describe('unknown types', () => {
     it('should resolve unknown type to any', () => {
       const result = resolveVarType({ type: 'unknown_type' })
diff --git a/web/app/components/workflow/nodes/tool/output-schema-utils.ts b/web/app/components/workflow/nodes/tool/output-schema-utils.ts
index 141c679da0..630673e3e9 100644
--- a/web/app/components/workflow/nodes/tool/output-schema-utils.ts
+++ b/web/app/components/workflow/nodes/tool/output-schema-utils.ts
@@ -2,6 +2,30 @@ import type { SchemaTypeDefinition } from '@/service/use-common'
 import { VarType } from '@/app/components/workflow/types'
 import { getMatchedSchemaType } from '../_base/components/variable/use-match-schema-type'

+/**
+ * Workflow-as-tool and some internal APIs store Dify VarType strings (e.g. `array[string]`)
+ * in JSON Schema `type` instead of standard `{ type: 'array', items: { type: 'string' } }`.
+ * Map those compact strings to VarType so downstream (e.g. Code node var picker) does not
+ * fall back to `any` and get filtered out.
+ */
+const resolveDifyCompactTypeString = (typeStr: string): VarType | undefined => {
+  const trimmed = typeStr.trim()
+  const m = /^array\[(string|number|integer|boolean|object|file|any)\]$/i.exec(trimmed)
+  if (!m)
+    return undefined
+  const inner = m[1].toLowerCase()
+  const map: Record<string, VarType> = {
+    string: VarType.arrayString,
+    number: VarType.arrayNumber,
+    integer: VarType.arrayNumber,
+    boolean: VarType.arrayBoolean,
+    object: VarType.arrayObject,
+    file: VarType.arrayFile,
+    any: VarType.arrayAny,
+  }
+  return map[inner]
+}
+
 /**
  * Normalizes a JSON Schema type to a simple string type.
  * Handles complex schemas with oneOf, anyOf, allOf.
@@ -54,6 +78,12 @@ export const resolveVarType = (
   schemaTypeDefinitions?: SchemaTypeDefinition[],
 ): { type: VarType, schemaType?: string } => {
   const schemaType = getMatchedSchemaType(schema, schemaTypeDefinitions)
+  if (schema && typeof schema.type === 'string') {
+    const compact = resolveDifyCompactTypeString(schema.type)
+    if (compact !== undefined)
+      return { type: compact, schemaType }
+  }
+
   const normalizedType = normalizeJsonSchemaType(schema)

   switch (normalizedType) {