Merge remote-tracking branch 'myori/main' into feat/collaboration2

This commit is contained in:
hjlarry 2026-04-10 09:41:47 +08:00
commit 59e752dcd3
38 changed files with 774 additions and 474 deletions

View File

@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
raise InvokeRateLimitHttpError(ex.description)
# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=True
# )
#
# return result
# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
# @setup_required
# @login_required
# @account_initialization_required
# @get_rag_pipeline
# def post(self, pipeline: Pipeline, node_id: str):
# """
# Run rag pipeline datasource
# """
# # The role of the current user in the ta table must be admin, owner, or editor
# if not current_user.has_edit_permission:
# raise Forbidden()
#
# if not isinstance(current_user, Account):
# raise Forbidden()
#
# parser = (reqparse.RequestParser()
# .add_argument("job_id", type=str, required=True, nullable=False, location="json")
# .add_argument("datasource_type", type=str, required=True, location="json")
# )
# args = parser.parse_args()
#
# job_id = args.get("job_id")
# if job_id == None:
# raise ValueError("missing job_id")
# datasource_type = args.get("datasource_type")
# if datasource_type == None:
# raise ValueError("missing datasource_type")
#
# rag_pipeline_service = RagPipelineService()
# result = rag_pipeline_service.run_datasource_workflow_node_status(
# pipeline=pipeline,
# node_id=node_id,
# job_id=job_id,
# account=current_user,
# datasource_type=datasource_type,
# is_published=False
# )
#
# return result
#
@console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
class RagPipelinePublishedDatasourceNodeRunApi(Resource):
@console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])

View File

@ -7,7 +7,8 @@ import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
selected_action_id=payload.action,
form_data=payload.inputs,
submission_user_id=current_user.id,
)

View File

@ -7,7 +7,8 @@ import logging
from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from flask_restx import Resource
from pydantic import BaseModel
from sqlalchemy import select
from werkzeug.exceptions import Forbidden
@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
logger = logging.getLogger(__name__)
class HumanInputFormSubmitPayload(BaseModel):
inputs: dict
action: str
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
ip_address = extract_remote_ip(request)
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
selected_action_id=payload.action,
form_data=payload.inputs,
submission_end_user_id=None,
# submission_end_user_id=_end_user.id,
)

View File

@ -2,7 +2,6 @@ import logging
import time
from typing import cast
from graphon.entities import GraphInitParams
from graphon.enums import WorkflowType
from graphon.graph import Graph
from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
@ -22,7 +21,7 @@ from core.app.entities.app_invoke_entities import (
)
from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
# graph_config["nodes"] = real_run_nodes
# graph_config["edges"] = real_edges
# init graph
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
# Create explicit graph init context for Graph.init.
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
),
run_context=run_context,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
if start_node_id is None:

View File

@ -3,7 +3,6 @@ import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from graphon.entities import GraphInitParams
from graphon.entities.graph_config import NodeConfigDictAdapter
from graphon.entities.pause_reason import HumanInputRequired
from graphon.graph import Graph
@ -67,7 +66,12 @@ from core.app.entities.queue_entities import (
QueueWorkflowSucceededEvent,
)
from core.rag.entities import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.node_factory import (
DifyGraphInitContext,
DifyNodeFactory,
get_default_root_node_id,
resolve_workflow_node_class,
)
from core.workflow.system_variables import (
build_bootstrap_variables,
default_system_variables,
@ -127,24 +131,25 @@ class WorkflowBasedAppRunner:
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
# Create explicit graph init context for Graph.init.
run_context = build_dify_run_context(
tenant_id=tenant_id or "",
app_id=self._app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow_id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=tenant_id or "",
app_id=self._app_id,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
),
run_context=run_context,
call_depth=0,
)
# Use the provided graph_runtime_state for consistent state management
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
@ -289,22 +294,23 @@ class WorkflowBasedAppRunner:
typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
# Create explicit graph init context for Graph.init.
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=graph_config,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=self._app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)

View File

@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
return ProtectedResourceMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
except (RequestError, ValidationError, json.JSONDecodeError):
continue # Try next URL
return None
@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
return OAuthMetadata.model_validate(response.json())
elif response.status_code == 404:
continue # Try next URL
except (RequestError, ValidationError):
except (RequestError, ValidationError, json.JSONDecodeError):
continue # Try next URL
return None
@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
else:
return False, ""
return False, ""
except RequestError:
except (RequestError, json.JSONDecodeError, IndexError):
# Not support resource discovery, fall back to well-known OAuth metadata
return False, ""

View File

@ -61,27 +61,28 @@ class TokenBufferMemory:
:param is_user_message: whether this is a user message
:return: PromptMessage
"""
if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
match self.conversation.mode:
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
app = self.conversation.app
if not app:
raise ValueError("App not found for conversation")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
if not message.workflow_run_id:
raise ValueError("Workflow run ID not found")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
)
if not workflow_run:
raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
if not workflow:
raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
case _:
raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.HIGH
if file_extra_config and app_record:

View File

@ -5,6 +5,7 @@ from typing import Any
from elasticsearch import Elasticsearch
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -19,6 +20,16 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class HuaweiElasticsearchParamsDict(TypedDict, total=False):
hosts: list[str]
verify_certs: bool
ssl_show_warn: bool
request_timeout: int
retry_on_timeout: bool
max_retries: int
basic_auth: tuple[str, str]
def create_ssl_context() -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
raise ValueError("config HOSTS is required")
return values
def to_elasticsearch_params(self) -> dict[str, Any]:
params = {
"hosts": self.hosts.split(","),
"verify_certs": False,
"ssl_show_warn": False,
"request_timeout": 30000,
"retry_on_timeout": True,
"max_retries": 10,
}
def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
params = HuaweiElasticsearchParamsDict(
hosts=self.hosts.split(","),
verify_certs=False,
ssl_show_warn=False,
request_timeout=30000,
retry_on_timeout=True,
max_retries=10,
)
if self.username and self.password:
params["basic_auth"] = (self.username, self.password)
return params

View File

@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from tenacity import retry, stop_after_attempt, wait_exponential
from typing_extensions import TypedDict
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
UGC_INDEX_PREFIX = "ugc_index"
class LindormOpenSearchParamsDict(TypedDict, total=False):
hosts: str | None
use_ssl: bool
pool_maxsize: int
timeout: int
http_auth: tuple[str, str]
class LindormVectorStoreConfig(BaseModel):
hosts: str | None
username: str | None = None
@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
raise ValueError("config PASSWORD is required")
return values
def to_opensearch_params(self) -> dict[str, Any]:
params: dict[str, Any] = {
"hosts": self.hosts,
"use_ssl": False,
"pool_maxsize": 128,
"timeout": 30,
}
def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
params = LindormOpenSearchParamsDict(
hosts=self.hosts,
use_ssl=False,
pool_maxsize=128,
timeout=30,
)
if self.username and self.password:
params["http_auth"] = (self.username, self.password)
return params

View File

@ -6,6 +6,7 @@ from uuid import uuid4
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
from typing_extensions import TypedDict
from configs import dify_config
from configs.middleware.vdb.opensearch_config import AuthMethod
@ -21,6 +22,20 @@ from models.dataset import Dataset
logger = logging.getLogger(__name__)
class _OpenSearchHostDict(TypedDict):
host: str
port: int
class OpenSearchParamsDict(TypedDict, total=False):
hosts: list[_OpenSearchHostDict]
use_ssl: bool
verify_certs: bool
connection_class: type
pool_maxsize: int
http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
class OpenSearchConfig(BaseModel):
host: str
port: int
@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.verify_certs,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
def to_opensearch_params(self) -> OpenSearchParamsDict:
params = OpenSearchParamsDict(
hosts=[{"host": self.host, "port": self.port}],
use_ssl=self.secure,
verify_certs=self.verify_certs,
connection_class=Urllib3HttpConnection,
pool_maxsize=20,
)
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")

View File

@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str:
# Match Unicode ranges for punctuation and symbols
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
pattern = re.compile(
r"""
^
(?:
[\u2000-\u2025] # General Punctuation: spaces, quotes, dashes
| [\u2027-\u206F] # General Punctuation: ellipsis, underscores, etc.
| [\u2E00-\u2E7F] # Supplemental Punctuation: medieval, ancient marks
| [\u3000-\u300F] # CJK Punctuation: 、。〃「」『》』 (excludes 【】)
| [\u3012-\u303F] # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc.
| ["#$%&'()*+,./:;<=>?@^_`~] # ASCII punctuation (excludes []【】)
)+
""",
re.VERBOSE,
)
return re.sub(pattern, "", text)

View File

@ -1,6 +1,7 @@
import importlib
import pkgutil
from collections.abc import Callable, Iterator, Mapping, MutableMapping
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast, final, override
@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset(
)
@dataclass(frozen=True, slots=True)
class DifyGraphInitContext:
"""Explicit graph-init values owned by the workflow layer.
Dify is gradually removing direct `GraphInitParams` construction from its
production call sites. Keep the translation here until `graphon` exposes an
equivalent explicit API.
"""
workflow_id: str
graph_config: Mapping[str, Any]
run_context: Mapping[str, Any]
call_depth: int
def to_graph_init_params(self) -> "GraphInitParams":
from graphon.entities import GraphInitParams
return GraphInitParams(
workflow_id=self.workflow_id,
graph_config=self.graph_config,
run_context=self.run_context,
call_depth=self.call_depth,
)
def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
package = importlib.import_module(package_name)
for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
@ -237,6 +263,19 @@ class DifyNodeFactory(NodeFactory):
Default implementation of NodeFactory that resolves node classes from the live registry.
"""
@classmethod
def from_graph_init_context(
cls,
*,
graph_init_context: DifyGraphInitContext,
graph_runtime_state: "GraphRuntimeState",
) -> "DifyNodeFactory":
"""Bridge Dify's explicit init context into the current `graphon` API."""
return cls(
graph_init_params=graph_init_context.to_graph_init_params(),
graph_runtime_state=graph_runtime_state,
)
def __init__(
self,
graph_init_params: "GraphInitParams",

View File

@ -29,7 +29,7 @@ class TriggerWebhookNode(Node[WebhookData]):
def post_init(self) -> None:
from core.workflow.node_runtime import DifyFileReferenceFactory
self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

View File

@ -24,7 +24,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
from core.app.file_access import DatabaseFileAccessController
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
from core.workflow.node_factory import (
DifyGraphInitContext,
DifyNodeFactory,
is_start_node_type,
resolve_workflow_node_class,
)
from core.workflow.system_variables import (
default_system_variables,
get_node_creation_preload_selectors,
@ -251,17 +256,18 @@ class WorkflowEntry:
node_version = str(node_config_data.version)
node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
# init graph init params and runtime state
graph_init_params = GraphInitParams(
# init graph context and runtime state
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
@ -313,8 +319,8 @@ class WorkflowEntry:
)
# init workflow run state
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)
@ -409,17 +415,18 @@ class WorkflowEntry:
variable_pool = VariablePool()
add_variables_to_pool(variable_pool, default_system_variables())
# init graph init params and runtime state
graph_init_params = GraphInitParams(
# init graph context and runtime state
run_context = build_dify_run_context(
tenant_id=tenant_id,
app_id="",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id="",
graph_config=graph_dict,
run_context=build_dify_run_context(
tenant_id=tenant_id,
app_id="",
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
@ -430,8 +437,8 @@ class WorkflowEntry:
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
node_factory = DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=graph_runtime_state,
)
node = node_factory.create_node(node_config)

View File

@ -5,12 +5,30 @@ from typing import Any
import pytz # type: ignore[import-untyped]
from celery import Celery, Task
from celery.schedules import crontab
from typing_extensions import TypedDict
from configs import dify_config
from dify_app import DifyApp
def get_celery_ssl_options() -> dict[str, Any] | None:
class _CelerySentinelKwargsDict(TypedDict):
socket_timeout: float | None
password: str | None
class CelerySentinelTransportDict(TypedDict):
master_name: str | None
sentinel_kwargs: _CelerySentinelKwargsDict
class CelerySSLOptionsDict(TypedDict):
ssl_cert_reqs: int
ssl_ca_certs: str | None
ssl_certfile: str | None
ssl_keyfile: str | None
def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
"""Get SSL configuration for Celery broker/backend connections."""
# Only apply SSL if we're using Redis as broker/backend
if not dify_config.BROKER_USE_SSL:
@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:
ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
ssl_options = {
"ssl_cert_reqs": ssl_cert_reqs,
"ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
"ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
"ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
}
return ssl_options
return CelerySSLOptionsDict(
ssl_cert_reqs=ssl_cert_reqs,
ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
)
def get_celery_broker_transport_options() -> dict[str, Any]:
def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
"""Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
if dify_config.CELERY_USE_SENTINEL:
return {
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
"password": dify_config.CELERY_SENTINEL_PASSWORD,
},
}
return CelerySentinelTransportDict(
master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
sentinel_kwargs=_CelerySentinelKwargsDict(
socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
password=dify_config.CELERY_SENTINEL_PASSWORD,
),
)
return {}

View File

@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
def suggested_questions_list(self) -> list[str]:
return json.loads(self.suggested_questions) if self.suggested_questions else []
def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})
@property
def suggested_questions_after_answer_dict(self) -> EnabledConfig:
return cast(
EnabledConfig,
json.loads(self.suggested_questions_after_answer)
if self.suggested_questions_after_answer
else {"enabled": False},
)
return self._get_enabled_config(self.suggested_questions_after_answer)
@property
def speech_to_text_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
return self._get_enabled_config(self.speech_to_text)
@property
def text_to_speech_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
return self._get_enabled_config(self.text_to_speech)
@property
def retriever_resource_dict(self) -> EnabledConfig:
return cast(
EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
)
return self._get_enabled_config(self.retriever_resource, default_enabled=True)
@property
def annotation_reply_dict(self) -> AnnotationReplyConfig:
@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):
@property
def more_like_this_dict(self) -> EnabledConfig:
return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
return self._get_enabled_config(self.more_like_this)
@property
def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
class TrialApp(Base):
class TrialApp(TypeBase):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
@ -911,18 +907,26 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
)
trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
return db.session.scalar(select(App).where(App.id == self.app_id))
class AccountTrialAppRecord(Base):
class AccountTrialAppRecord(TypeBase):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
@ -930,11 +934,19 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, default=gen_uuidv4_string)
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
id: Mapped[str] = mapped_column(
StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
)
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime,
nullable=False,
insert_default=func.current_timestamp(),
server_default=func.current_timestamp(),
init=False,
)
@property
def app(self) -> App | None:

View File

@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
record_id = resolve_file_record_id(mapping)
transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])
if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
mapping["tool_file_id"] = record_id
elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
mapping["upload_file_id"] = record_id
elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
mapping["datasource_file_id"] = record_id
match transfer_method:
case FileTransferMethod.TOOL_FILE if record_id:
mapping["tool_file_id"] = record_id
case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
mapping["upload_file_id"] = record_id
case FileTransferMethod.DATASOURCE_FILE if record_id:
mapping["datasource_file_id"] = record_id
case _:
pass
if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
remote_url = mapping.get("remote_url")

View File

@ -173,7 +173,7 @@ dev = [
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
"pytest-xdist>=3.8.0",
"pyrefly>=0.59.1",
"pyrefly>=0.60.0",
]
############################################################

View File

@ -467,61 +467,67 @@ class AppDslService:
)
# Initialize app based on mode
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
match app_mode:
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
workflow_data = data.get("workflow")
if not workflow_data or not isinstance(workflow_data, dict):
raise ValueError("Missing workflow data for workflow/advanced chat app")
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
environment_variables_list = workflow_data.get("environment_variables", [])
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = workflow_data.get("conversation_variables", [])
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj)
for obj in conversation_variables_list
]
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
workflow_service = WorkflowService()
current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
if current_draft_workflow:
unique_hash = current_draft_workflow.unique_hash
else:
unique_hash = None
graph = workflow_data.get("graph", {})
for node in graph.get("nodes", []):
if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
dataset_ids = node["data"].get("dataset_ids", [])
node["data"]["dataset_ids"] = [
decrypted_id
for dataset_id in dataset_ids
if (
decrypted_id := self.decrypt_dataset_id(
encrypted_data=dataset_id, tenant_id=app.tenant_id
)
)
]
workflow_service.sync_draft_workflow(
app_model=app,
graph=workflow_data.get("graph", {}),
features=workflow_data.get("features", {}),
unique_hash=unique_hash,
account=account,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
# Initialize model config
model_config = data.get("model_config")
if not model_config or not isinstance(model_config, dict):
raise ValueError("Missing model_config for chat/agent-chat/completion app")
# Initialize or update model config
if not app.app_model_config:
app_model_config = AppModelConfig(
app_id=app.id, created_by=account.id, updated_by=account.id
).from_model_config_dict(cast(AppModelConfigDict, model_config))
app_model_config.id = str(uuid4())
app.app_model_config_id = app_model_config.id
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
else:
raise ValueError("Invalid app mode")
self._session.add(app_model_config)
app_model_config_was_updated.send(app, app_model_config=app_model_config)
case _:
raise ValueError("Invalid app mode")
return app
@classmethod

View File

@ -132,8 +132,8 @@ class FileService:
return file_size <= file_size_limit
def get_file_base64(self, file_id: str) -> str:
upload_file = (
self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = self._session_maker(expire_on_commit=False).scalar(
select(UploadFile).where(UploadFile.id == file_id).limit(1)
)
if not upload_file:
raise NotFound("File not found")
@ -178,7 +178,7 @@ class FileService:
Return a short text preview extracted from a document file.
"""
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found")
@ -200,7 +200,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -220,7 +220,7 @@ class FileService:
raise NotFound("File not found or signature is invalid")
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -231,7 +231,7 @@ class FileService:
def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -247,7 +247,7 @@ class FileService:
def get_file_content(self, file_id: str) -> str:
with self._session_maker(expire_on_commit=False) as session:
upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
if not upload_file:
raise NotFound("File not found")

View File

@ -1,14 +1,13 @@
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models.account import TenantPluginAutoUpgradeStrategy
class PluginAutoUpgradeService:
@staticmethod
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session:
return session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@ -24,7 +23,7 @@ class PluginAutoUpgradeService:
exclude_plugins: list[str],
include_plugins: list[str],
) -> bool:
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session:
exist_strategy = session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@ -51,7 +50,7 @@ class PluginAutoUpgradeService:
@staticmethod
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
with sessionmaker(bind=db.engine).begin() as session:
with session_factory.create_session() as session:
exist_strategy = session.scalar(
select(TenantPluginAutoUpgradeStrategy)
.where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)

View File

@ -1,3 +1,5 @@
from typing import Any, TypedDict
from sqlalchemy import select
from constants.languages import languages
@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
from services.recommend_app.recommend_app_type import RecommendAppType
class RecommendedAppItemDict(TypedDict):
id: str
app: App | None
app_id: str
description: Any
copyright: Any
privacy_policy: Any
custom_disclaimer: str
category: str
position: int
is_listed: bool
class RecommendedAppsResultDict(TypedDict):
recommended_apps: list[RecommendedAppItemDict]
categories: list[str]
class RecommendedAppDetailDict(TypedDict):
id: str
name: str
icon: Any
icon_background: str | None
mode: str
export_data: str
class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
Retrieval recommended app from database
"""
def get_recommended_apps_and_categories(self, language: str):
def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict:
result = self.fetch_recommended_apps_from_db(language)
return result
def get_recommend_app_detail(self, app_id: str):
def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
result = self.fetch_recommended_app_detail_from_db(app_id)
return result
@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return RecommendAppType.DATABASE
@classmethod
def fetch_recommended_apps_from_db(cls, language: str):
def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict:
"""
Fetch recommended apps from db.
:param language: language
@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
).all()
categories = set()
recommended_apps_result = []
recommended_apps_result: list[RecommendedAppItemDict] = []
for recommended_app in recommended_apps:
app = recommended_app.app
if not app or not app.is_public:
@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
if not site:
continue
recommended_app_result = {
recommended_app_result: RecommendedAppItemDict = {
"id": recommended_app.id,
"app": recommended_app.app,
"app_id": recommended_app.app_id,
@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
categories.add(recommended_app.category)
return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories))
@classmethod
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None:
"""
Fetch recommended app detail from db.
:param app_id: App ID
@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
if not app_model or not app_model.is_public:
return None
return {
"id": app_model.id,
"name": app_model.name,
"icon": app_model.icon,
"icon_background": app_model.icon_background,
"mode": app_model.mode,
"export_data": AppDslService.export_dsl(app_model=app_model),
}
return RecommendedAppDetailDict(
id=app_model.id,
name=app_model.name,
icon=app_model.icon,
icon_background=app_model.icon_background,
mode=app_model.mode,
export_data=AppDslService.export_dsl(app_model=app_model),
)

View File

@ -104,32 +104,32 @@ class WebhookService:
"""
with Session(db.engine) as session:
# Get webhook trigger
webhook_trigger = (
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first()
webhook_trigger = session.scalar(
select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1)
)
if not webhook_trigger:
raise ValueError(f"Webhook not found: {webhook_id}")
if is_debug:
workflow = (
session.query(Workflow)
.filter(
workflow = session.scalar(
select(Workflow)
.where(
Workflow.app_id == webhook_trigger.app_id,
Workflow.version == Workflow.VERSION_DRAFT,
)
.order_by(Workflow.created_at.desc())
.first()
.limit(1)
)
else:
# Check if the corresponding AppTrigger exists
app_trigger = (
session.query(AppTrigger)
.filter(
app_trigger = session.scalar(
select(AppTrigger)
.where(
AppTrigger.app_id == webhook_trigger.app_id,
AppTrigger.node_id == webhook_trigger.node_id,
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
)
.first()
.limit(1)
)
if not app_trigger:
@ -146,14 +146,14 @@ class WebhookService:
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
# Get workflow
workflow = (
session.query(Workflow)
.filter(
workflow = session.scalar(
select(Workflow)
.where(
Workflow.app_id == webhook_trigger.app_id,
Workflow.version != Workflow.VERSION_DRAFT,
)
.order_by(Workflow.created_at.desc())
.first()
.limit(1)
)
if not workflow:
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")

View File

@ -5,7 +5,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, cast
from graphon.entities import GraphInitParams, WorkflowNodeExecution
from graphon.entities import WorkflowNodeExecution
from graphon.entities.graph_config import NodeConfigDict
from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import (
@ -48,7 +48,12 @@ from core.workflow.human_input_compat import (
normalize_human_input_node_data_for_graph,
parse_human_input_delivery_methods,
)
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type
from core.workflow.node_factory import (
LATEST_VERSION,
DifyGraphInitContext,
get_node_type_classes_mapping,
is_start_node_type,
)
from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
@ -1204,18 +1209,20 @@ class WorkflowService:
node_config: NodeConfigDict,
variable_pool: VariablePool,
) -> HumanInputNode:
graph_init_params = GraphInitParams(
run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=account.id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id,
graph_config=workflow.graph_dict,
run_context=build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=account.id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
run_context=run_context,
call_depth=0,
)
graph_init_params = graph_init_context.to_graph_init_params()
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
@ -1225,7 +1232,7 @@ class WorkflowService:
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context),
runtime=DifyHumanInputNodeRuntime(run_context),
)
return node

View File

@ -3,6 +3,7 @@ import time
import click
from celery import shared_task # type: ignore
from sqlalchemy import select, update
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
for segment in segments:
@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
elif action == "update":
dataset_documents = (
session.query(DatasetDocument)
.where(
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
)
session.commit()
@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
segments = session.scalars(
select(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
).all()
if segments:
documents = []
multimodal_documents = []
@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
session.execute(
update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
)
session.commit()
else:

View File

@ -862,6 +862,15 @@ class TestAuthOrchestration:
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
# JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
@patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_authorization_server_metadata(self, mock_get):
# Success
@ -892,6 +901,14 @@ class TestAuthOrchestration:
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
# JSONDecodeError (non-JSON 200 response)
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
def test_get_effective_scope(self):
prm = ProtectedResourceMetadata(
resource="https://api.example.com",
@ -997,6 +1014,24 @@ class TestAuthOrchestration:
supported, url = check_support_resource_discovery("https://api")
assert supported is False
# Case 6: JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_res = Mock()
bad_json_res.status_code = 200
bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
assert url == ""
# Case 7: Empty authorization_servers array (IndexError)
empty_res = Mock()
empty_res.status_code = 200
empty_res.json.return_value = {"authorization_servers": []}
mock_get.return_value = empty_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
def test_discover_oauth_metadata(self):
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:

View File

@ -110,6 +110,34 @@ class TestFetchMemory:
)
class TestDifyGraphInitContext:
def test_to_graph_init_params_preserves_explicit_values(self):
run_context = {
DIFY_RUN_CONTEXT_KEY: DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
"extra": "value",
}
graph_config = {"nodes": [], "edges": []}
graph_init_context = node_factory.DifyGraphInitContext(
workflow_id="workflow-id",
graph_config=graph_config,
run_context=run_context,
call_depth=2,
)
result = graph_init_context.to_graph_init_params()
assert result.workflow_id == "workflow-id"
assert result.graph_config == graph_config
assert result.run_context == run_context
assert result.call_depth == 2
class TestDefaultWorkflowCodeExecutor:
def test_execute_delegates_to_code_executor(self, monkeypatch):
executor = node_factory.DefaultWorkflowCodeExecutor()
@ -172,6 +200,23 @@ class TestCodeExecutorJinja2TemplateRenderer:
class TestDifyNodeFactoryInit:
def test_from_graph_init_context_translates_before_init(self):
graph_init_context = MagicMock()
graph_init_context.to_graph_init_params.return_value = sentinel.graph_init_params
with patch.object(node_factory.DifyNodeFactory, "__init__", return_value=None) as init:
factory = node_factory.DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)
assert isinstance(factory, node_factory.DifyNodeFactory)
graph_init_context.to_graph_init_params.assert_called_once_with()
init.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
)
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
graph_runtime_state = sentinel.graph_runtime_state

View File

@ -349,7 +349,7 @@ class TestWorkflowEntrySingleStepRun:
]
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(
workflow_entry,
"GraphRuntimeState",
@ -358,7 +358,7 @@ class TestWorkflowEntrySingleStepRun:
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
patch.object(
@ -412,12 +412,12 @@ class TestWorkflowEntrySingleStepRun:
raise NotImplementedError
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
@ -481,12 +481,12 @@ class TestWorkflowEntrySingleStepRun:
return {"question": ["node", "question"]}
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object(
@ -541,12 +541,12 @@ class TestWorkflowEntrySingleStepRun:
return "1"
with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory,
patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool"),
patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
@ -651,14 +651,18 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool,
patch.object(
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params
) as graph_init_params,
workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context
) as graph_init_context_cls,
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
) as build_dify_run_context,
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls,
patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
) as dify_node_factory_cls,
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",
@ -688,7 +692,7 @@ class TestWorkflowEntryHelpers:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_params.assert_called_once_with(
graph_init_context_cls.assert_called_once_with(
workflow_id="",
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
"node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}
@ -697,7 +701,7 @@ class TestWorkflowEntryHelpers:
call_depth=0,
)
dify_node_factory_cls.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_init_context=sentinel.graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)
mapping_user_inputs_to_variable_pool.assert_called_once_with(
@ -734,11 +738,15 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables),
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
patch.object(workflow_entry, "add_variables_to_pool"),
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params),
patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory),
patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
),
patch.object(
workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool",

View File

@ -6,23 +6,23 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"
def _patched_session():
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager."""
"""Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock()
mock_sessionmaker = MagicMock()
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False)
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker)
db_patcher = patch(f"{MODULE}.db")
return patcher, db_patcher, session
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
mock_factory = MagicMock()
mock_factory.create_session.return_value = session
patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, session
class TestGetStrategy:
def test_returns_strategy_when_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
strategy = MagicMock()
session.scalar.return_value = strategy
with p1, p2:
with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.get_strategy("t1")
@ -30,10 +30,10 @@ class TestGetStrategy:
assert result is strategy
def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with p1, p2:
with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.get_strategy("t1")
@ -43,10 +43,10 @@ class TestGetStrategy:
class TestChangeStrategy:
def test_creates_new_strategy(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.return_value = MagicMock()
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -63,11 +63,11 @@ class TestChangeStrategy:
session.add.assert_called_once()
def test_updates_existing_strategy(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
session.scalar.return_value = existing
with p1, p2:
with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.change_strategy(
@ -89,12 +89,11 @@ class TestChangeStrategy:
class TestExcludePlugin:
def test_creates_default_strategy_when_none_exists(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
session.scalar.return_value = None
with (
p1,
p2,
patch(f"{MODULE}.select"),
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
@ -110,13 +109,13 @@ class TestExcludePlugin:
cs.assert_called_once()
def test_appends_to_exclude_list_in_exclude_mode(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p-existing"]
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -128,13 +127,13 @@ class TestExcludePlugin:
assert existing.exclude_plugins == ["p-existing", "p-new"]
def test_removes_from_include_list_in_partial_mode(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "partial"
existing.include_plugins = ["p1", "p2"]
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -146,12 +145,12 @@ class TestExcludePlugin:
assert existing.include_plugins == ["p2"]
def test_switches_to_exclude_mode_from_all(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "all"
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"
@ -164,13 +163,13 @@ class TestExcludePlugin:
assert existing.exclude_plugins == ["p1"]
def test_no_duplicate_in_exclude_list(self):
p1, p2, session = _patched_session()
p1, session = _patched_session()
existing = MagicMock()
existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p1"]
session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all"

View File

@ -165,7 +165,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "test_key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage:
mock_storage.load_once.return_value = b"test content"
@ -178,7 +178,7 @@ class TestFileService:
mock_storage.load_once.assert_called_once_with("test_key")
def test_get_file_base64_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_base64("non_existent")
@ -215,7 +215,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "pdf"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
mock_extract.return_value = "Extracted text content"
@ -227,7 +227,7 @@ class TestFileService:
assert result == "Extracted text content"
def test_get_file_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_preview("non_existent")
@ -235,7 +235,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "exe"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError):
file_service.get_file_preview("file_id")
@ -246,7 +246,7 @@ class TestFileService:
upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with (
patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
@ -269,7 +269,7 @@ class TestFileService:
file_service.get_image_preview("file_id", "ts", "nonce", "sign")
def test_get_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
@ -279,7 +279,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(UnsupportedFileTypeError):
@ -289,7 +289,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with (
patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
@ -309,7 +309,7 @@ class TestFileService:
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"):
@ -321,7 +321,7 @@ class TestFileService:
upload_file.extension = "png"
upload_file.mime_type = "image/png"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"image content"
@ -330,7 +330,7 @@ class TestFileService:
assert mime == "image/png"
def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_public_image_preview("file_id")
@ -338,7 +338,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError):
file_service.get_public_image_preview("file_id")
@ -346,7 +346,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id"
upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file
mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"hello world"
@ -354,7 +354,7 @@ class TestFileService:
assert result == "hello world"
def test_get_file_content_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None
mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"):
file_service.get_file_content("file_id")

View File

@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App:
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange
fake_session = MagicMock()
fake_session.query.return_value = _FakeQuery(None)
fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session)
# Act / Assert
@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun
# Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)]
fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session)
# Act / Assert
@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)]
fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act
@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)]
fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session)
# Act

View File

@ -2826,9 +2826,9 @@ class TestWorkflowServiceFreeNodeExecution:
variable_pool = MagicMock()
with (
patch("services.workflow_service.GraphInitParams") as mock_graph_init_params,
patch("services.workflow_service.DifyGraphInitContext") as mock_graph_init_context_cls,
patch("services.workflow_service.GraphRuntimeState"),
patch("services.workflow_service.build_dify_run_context"),
patch("services.workflow_service.build_dify_run_context") as mock_build_dify_run_context,
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
patch("services.workflow_service.HumanInputNode") as mock_node_cls,
):
@ -2837,4 +2837,17 @@ class TestWorkflowServiceFreeNodeExecution:
)
assert node == mock_node_cls.return_value
mock_node_cls.assert_called_once()
mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context)
mock_graph_init_context_cls.assert_called_once_with(
workflow_id="wf-1",
graph_config=workflow.graph_dict,
run_context=mock_build_dify_run_context.return_value,
call_depth=0,
)
mock_runtime_cls.assert_called_once_with(mock_build_dify_run_context.return_value)
mock_node_cls.assert_called_once_with(
id="n-1",
config=node_config,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,
)

View File

@ -19,7 +19,57 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
# Leading symbols before markdown link are removed, including the opening bracket [
("@[Test](https://example.com)", "Test](https://example.com)"),
("@[Test](https://example.com)", "[Test](https://example.com)"),
("~~标题~~", "标题~~"),
('""quoted', "quoted"),
("''test", "test"),
("##话题", "话题"),
("$$价格", "价格"),
("%%百分比", "百分比"),
("&&与逻辑", "与逻辑"),
("((括号))", "括号))"),
("**强调**", "强调**"),
("++自增", "自增"),
(",,逗号", "逗号"),
("..省略", "省略"),
("//注释", "注释"),
("::范围", "范围"),
(";;分号", "分号"),
("<<左移", "左移"),
("==等于", "等于"),
(">>右移", "右移"),
("??疑问", "疑问"),
("@@提及", "提及"),
("^^上标", "上标"),
("__下划线", "下划线"),
("``代码", "代码"),
("~~删除线", "删除线"),
(" 全角空格开头", "全角空格开头"),
("、顿号开头", "顿号开头"),
("。句号开头", "句号开头"),
("「引号」测试", "引号」测试"),
("『书名号』", "书名号』"),
("【保留】测试", "【保留】测试"),
("〖括号〗测试", "括号〗测试"),
("〔括号〕测试", "括号〕测试"),
("~~【保留】~~", "【保留】~~"),
('"[公告]"', '[公告]"'),
("[公告] 更新", "[公告] 更新"),
("【通知】重要", "【通知】重要"),
("[[嵌套]]", "[[嵌套]]"),
("【【嵌套】】", "【【嵌套】】"),
("[【混合】]", "[【混合】]"),
("normal text", "normal text"),
("123数字", "123数字"),
("中文开头", "中文开头"),
("alpha", "alpha"),
("~", ""),
("", ""),
("[", "["),
("~~~", ""),
("【【【", "【【【"),
("\t制表符", "\t制表符"),
("\n换行", "\n换行"),
],
)
def test_remove_leading_symbols(input_text, expected_output):

24
api/uv.lock generated
View File

@ -1585,7 +1585,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.20.0" },
{ name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.59.1" },
{ name = "pyrefly", specifier = ">=0.60.0" },
{ name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" },
@ -4850,19 +4850,19 @@ wheels = [
[[package]]
name = "pyrefly"
version = "0.59.1"
version = "0.60.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" }
sdist = { url = "https://files.pythonhosted.org/packages/c6/c7/28d14b64888e2d03815627ebff8d57a9f08389c4bbebfe70ae1ed98a1267/pyrefly-0.60.0.tar.gz", hash = "sha256:2499f5b6ff5342e86dfe1cd94bcce133519bbbc93b7ad5636195fea4f0fa3b81", size = 5500389, upload-time = "2026-04-06T19:57:30.643Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" },
{ url = "https://files.pythonhosted.org/packages/31/99/6c9984a09220e5eb7dd5c869b7a32d25c3d06b5e8854c6eb679db1145c3e/pyrefly-0.60.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bf1691af0fee69d0c99c3c6e9d26ab6acd3c8afef96416f9ba2e74934833b7b5", size = 12921262, upload-time = "2026-04-06T19:57:00.745Z" },
{ url = "https://files.pythonhosted.org/packages/05/b3/6216aa3c00c88e59a27eb4149851b5affe86eeea6129f4224034a32dddb0/pyrefly-0.60.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3e71b70c9b95545cf3b479bc55d1381b531de7b2380eb64411088a1e56b634cb", size = 12424413, upload-time = "2026-04-06T19:57:03.417Z" },
{ url = "https://files.pythonhosted.org/packages/9b/87/eb8dd73abd92a93952ac27a605e463c432fb250fb23186574038c7035594/pyrefly-0.60.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:680ee5f8f98230ea145652d7344708f5375786209c5bf03d8b911fdb0d0d4195", size = 35940884, upload-time = "2026-04-06T19:57:06.909Z" },
{ url = "https://files.pythonhosted.org/packages/0d/34/dc6aeb67b840c745fcee6db358295d554abe6ab555a7eaaf44624bd80bf1/pyrefly-0.60.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d0b20dbbe4aff15b959e8d825b7521a144c4122c11e57022e83b36568c54470", size = 38677220, upload-time = "2026-04-06T19:57:11.235Z" },
{ url = "https://files.pythonhosted.org/packages/66/6b/c863fcf7ef592b7d1db91502acf0d1113be8bed7a2a7143fc6f0dd90616f/pyrefly-0.60.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2911563c8e6b2eaefff68885c94727965469a35375a409235a7a4d2b7157dc15", size = 36907431, upload-time = "2026-04-06T19:57:15.074Z" },
{ url = "https://files.pythonhosted.org/packages/8e/a2/25ea095ab2ecca8e62884669b11a79f14299db93071685b73a97efbaf4f3/pyrefly-0.60.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a631d9d04705e303fe156f2e62551611bc7ef8066c34708ceebcfb3088bd55", size = 41447898, upload-time = "2026-04-06T19:57:19.382Z" },
{ url = "https://files.pythonhosted.org/packages/8e/2c/097bdc6e8d40676b28eb03710a4577bc3c7b803cd24693ac02bf15de3d67/pyrefly-0.60.0-py3-none-win32.whl", hash = "sha256:a08d69298da5626cf502d3debbb6944fd13d2f405ea6625363751f1ff570d366", size = 11913434, upload-time = "2026-04-06T19:57:22.887Z" },
{ url = "https://files.pythonhosted.org/packages/0a/d4/8d27fe310e830c8d11ab73db38b93f9fd2e218744b6efb1204401c9a74d5/pyrefly-0.60.0-py3-none-win_amd64.whl", hash = "sha256:56cf30654e708ae1dd635ffefcba4fa4b349dd7004a6ccc5c41e3a9bb944320c", size = 12745033, upload-time = "2026-04-06T19:57:25.517Z" },
{ url = "https://files.pythonhosted.org/packages/1f/ad/8eea1f8fb8209f91f6dbfe48000c9d05fd0cdb1b5b3157283c9b1dada55d/pyrefly-0.60.0-py3-none-win_arm64.whl", hash = "sha256:b6d27fba970f4777063c0227c54167d83bece1804ea34f69e7118e409ba038d2", size = 12246390, upload-time = "2026-04-06T19:57:28.141Z" },
]
[[package]]

13
pnpm-lock.yaml generated
View File

@ -520,8 +520,8 @@ catalogs:
specifier: 13.0.0
version: 13.0.0
vinext:
specifier: https://pkg.pr.new/vinext@adbf24d
version: 0.0.5
specifier: 0.0.41
version: 0.0.41
vite-plugin-inspect:
specifier: 12.0.0-beta.1
version: 12.0.0-beta.1
@ -1162,7 +1162,7 @@ importers:
version: 3.19.3
vinext:
specifier: 'catalog:'
version: https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
version: 0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
vite:
specifier: npm:@voidzero-dev/vite-plus-core@0.1.16
version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)'
@ -8336,9 +8336,8 @@ packages:
vfile@6.0.3:
resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}
vinext@https://pkg.pr.new/vinext@adbf24d:
resolution: {tarball: https://pkg.pr.new/vinext@adbf24d}
version: 0.0.5
vinext@0.0.41:
resolution: {integrity: sha512-fpQjNp6cIqjYGH2/kbhN2SdIYHEu79RdlII23SWsY1Qp7LM+je8GfTJH1sxw6dASxPhZKZB/jCmTm5d2/D25zw==}
engines: {node: '>=22'}
hasBin: true
peerDependencies:
@ -16586,7 +16585,7 @@ snapshots:
'@types/unist': 3.0.3
vfile-message: 4.0.3
vinext@https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
vinext@0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
dependencies:
'@unpic/react': 1.0.2(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
'@vercel/og': 0.8.6

View File

@ -221,7 +221,7 @@ catalog:
unist-util-visit: 5.1.0
use-context-selector: 2.0.0
uuid: 13.0.0
vinext: https://pkg.pr.new/vinext@adbf24d
vinext: 0.0.41
vite: npm:@voidzero-dev/vite-plus-core@0.1.16
vite-plugin-inspect: 12.0.0-beta.1
vite-plus: 0.1.16

View File

@ -229,6 +229,23 @@ describe('output-schema-utils', () => {
})
})
describe('Dify compact types (workflow-as-tool output_schema)', () => {
it('should resolve array[string] to arrayString (issue #34428)', () => {
const result = resolveVarType({ type: 'array[string]' })
expect(result.type).toBe(VarType.arrayString)
})
it('should resolve Array[string] case-insensitively', () => {
const result = resolveVarType({ type: 'Array[string]' })
expect(result.type).toBe(VarType.arrayString)
})
it('should resolve array[object] to arrayObject', () => {
const result = resolveVarType({ type: 'array[object]' })
expect(result.type).toBe(VarType.arrayObject)
})
})
describe('unknown types', () => {
it('should resolve unknown type to any', () => {
const result = resolveVarType({ type: 'unknown_type' })

View File

@ -2,6 +2,30 @@ import type { SchemaTypeDefinition } from '@/service/use-common'
import { VarType } from '@/app/components/workflow/types'
import { getMatchedSchemaType } from '../_base/components/variable/use-match-schema-type'
/**
* Workflow-as-tool and some internal APIs store Dify VarType strings (e.g. `array[string]`)
* in JSON Schema `type` instead of standard `{ type: 'array', items: { type: 'string' } }`.
* Map those compact strings to VarType so downstream (e.g. Code node var picker) does not
* fall back to `any` and get filtered out.
*/
const resolveDifyCompactTypeString = (typeStr: string): VarType | undefined => {
const trimmed = typeStr.trim()
const m = /^array\[(string|number|integer|boolean|object|file|any)\]$/i.exec(trimmed)
if (!m)
return undefined
const inner = m[1].toLowerCase()
const map: Record<string, VarType> = {
string: VarType.arrayString,
number: VarType.arrayNumber,
integer: VarType.arrayNumber,
boolean: VarType.arrayBoolean,
object: VarType.arrayObject,
file: VarType.arrayFile,
any: VarType.arrayAny,
}
return map[inner]
}
/**
* Normalizes a JSON Schema type to a simple string type.
* Handles complex schemas with oneOf, anyOf, allOf.
@ -54,6 +78,12 @@ export const resolveVarType = (
schemaTypeDefinitions?: SchemaTypeDefinition[],
): { type: VarType, schemaType?: string } => {
const schemaType = getMatchedSchemaType(schema, schemaTypeDefinitions)
if (schema && typeof schema.type === 'string') {
const compact = resolveDifyCompactTypeString(schema.type)
if (compact !== undefined)
return { type: compact, schemaType }
}
const normalizedType = normalizeJsonSchemaType(schema)
switch (normalizedType) {