Merge remote-tracking branch 'myori/main' into feat/collaboration2

hjlarry 2026-04-10 09:41:47 +08:00
commit 59e752dcd3
38 changed files with 774 additions and 474 deletions

View File

@@ -346,89 +346,6 @@ class PublishedRagPipelineRunApi(Resource):
             raise InvokeRateLimitHttpError(ex.description)
 
-# class RagPipelinePublishedDatasourceNodeRunStatusApi(Resource):
-#     @setup_required
-#     @login_required
-#     @account_initialization_required
-#     @get_rag_pipeline
-#     def post(self, pipeline: Pipeline, node_id: str):
-#         """
-#         Run rag pipeline datasource
-#         """
-#         # The role of the current user in the ta table must be admin, owner, or editor
-#         if not current_user.has_edit_permission:
-#             raise Forbidden()
-#
-#         if not isinstance(current_user, Account):
-#             raise Forbidden()
-#
-#         parser = (reqparse.RequestParser()
-#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#             .add_argument("datasource_type", type=str, required=True, location="json")
-#         )
-#         args = parser.parse_args()
-#
-#         job_id = args.get("job_id")
-#         if job_id == None:
-#             raise ValueError("missing job_id")
-#         datasource_type = args.get("datasource_type")
-#         if datasource_type == None:
-#             raise ValueError("missing datasource_type")
-#
-#         rag_pipeline_service = RagPipelineService()
-#         result = rag_pipeline_service.run_datasource_workflow_node_status(
-#             pipeline=pipeline,
-#             node_id=node_id,
-#             job_id=job_id,
-#             account=current_user,
-#             datasource_type=datasource_type,
-#             is_published=True
-#         )
-#
-#         return result
-
-
-# class RagPipelineDraftDatasourceNodeRunStatusApi(Resource):
-#     @setup_required
-#     @login_required
-#     @account_initialization_required
-#     @get_rag_pipeline
-#     def post(self, pipeline: Pipeline, node_id: str):
-#         """
-#         Run rag pipeline datasource
-#         """
-#         # The role of the current user in the ta table must be admin, owner, or editor
-#         if not current_user.has_edit_permission:
-#             raise Forbidden()
-#
-#         if not isinstance(current_user, Account):
-#             raise Forbidden()
-#
-#         parser = (reqparse.RequestParser()
-#             .add_argument("job_id", type=str, required=True, nullable=False, location="json")
-#             .add_argument("datasource_type", type=str, required=True, location="json")
-#         )
-#         args = parser.parse_args()
-#
-#         job_id = args.get("job_id")
-#         if job_id == None:
-#             raise ValueError("missing job_id")
-#         datasource_type = args.get("datasource_type")
-#         if datasource_type == None:
-#             raise ValueError("missing datasource_type")
-#
-#         rag_pipeline_service = RagPipelineService()
-#         result = rag_pipeline_service.run_datasource_workflow_node_status(
-#             pipeline=pipeline,
-#             node_id=node_id,
-#             job_id=job_id,
-#             account=current_user,
-#             datasource_type=datasource_type,
-#             is_published=False
-#         )
-#
-#         return result
-#
 
 @console_ns.route("/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/run")
 class RagPipelinePublishedDatasourceNodeRunApi(Resource):
     @console_ns.expect(console_ns.models[DatasourceNodeRunPayload.__name__])

View File

@@ -7,7 +7,8 @@ import logging
 from collections.abc import Generator
 
 from flask import Response, jsonify, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from sqlalchemy.orm import Session, sessionmaker
@@ -33,6 +34,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
 logger = logging.getLogger(__name__)
 
 
+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 def _jsonify_form_definition(form: Form) -> Response:
     payload = form.get_definition().model_dump()
     payload["expiration_time"] = int(form.expiration_time.timestamp())
@@ -84,10 +90,7 @@ class ConsoleHumanInputFormApi(Resource):
             "action": "Approve"
         }
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("action", type=str, required=True, location="json")
-        args = parser.parse_args()
+        payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
 
         current_user, _ = current_account_with_tenant()
         service = HumanInputService(db.engine)
@@ -107,8 +110,8 @@ class ConsoleHumanInputFormApi(Resource):
         service.submit_form_by_token(
             recipient_type=recipient_type,
             form_token=form_token,
-            selected_action_id=args["action"],
-            form_data=args["inputs"],
+            selected_action_id=payload.action,
+            form_data=payload.inputs,
             submission_user_id=current_user.id,
         )
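
Side note on the migration above: reqparse produced an automatic 400 response for bad input, while model_validate raises pydantic.ValidationError and leaves the translation to the caller (or a framework-level error handler). A minimal sketch of the new failure mode; the parse_payload helper and its error translation are illustrative, not code from this commit:

    from pydantic import BaseModel, ValidationError

    class HumanInputFormSubmitPayload(BaseModel):
        inputs: dict
        action: str

    def parse_payload(raw: object) -> HumanInputFormSubmitPayload:
        # model_validate() checks presence and types of all fields in one call.
        try:
            return HumanInputFormSubmitPayload.model_validate(raw)
        except ValidationError as exc:
            # Hypothetical translation; the real handlers may rely on an
            # application-wide error handler for ValidationError instead.
            raise ValueError(f"invalid submission: {exc.error_count()} error(s)") from exc

    parse_payload({"inputs": {"q": "hi"}, "action": "Approve"})  # ok
    # parse_payload({"inputs": {}})  # raises: field "action" is missing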

View File

@@ -7,7 +7,8 @@ import logging
 from datetime import datetime
 
 from flask import Response, request
-from flask_restx import Resource, reqparse
+from flask_restx import Resource
+from pydantic import BaseModel
 from sqlalchemy import select
 from werkzeug.exceptions import Forbidden
@@ -23,6 +24,12 @@ from services.human_input_service import Form, FormNotFoundError, HumanInputServ
 logger = logging.getLogger(__name__)
 
 
+class HumanInputFormSubmitPayload(BaseModel):
+    inputs: dict
+    action: str
+
+
 _FORM_SUBMIT_RATE_LIMITER = RateLimiter(
     prefix="web_form_submit_rate_limit",
     max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
@@ -112,10 +119,7 @@ class HumanInputFormApi(Resource):
             "action": "Approve"
         }
         """
-        parser = reqparse.RequestParser()
-        parser.add_argument("inputs", type=dict, required=True, location="json")
-        parser.add_argument("action", type=str, required=True, location="json")
-        args = parser.parse_args()
+        payload = HumanInputFormSubmitPayload.model_validate(request.get_json())
 
         ip_address = extract_remote_ip(request)
         if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
@@ -135,8 +139,8 @@ class HumanInputFormApi(Resource):
         service.submit_form_by_token(
             recipient_type=recipient_type,
             form_token=form_token,
-            selected_action_id=args["action"],
-            form_data=args["inputs"],
+            selected_action_id=payload.action,
+            form_data=payload.inputs,
             submission_end_user_id=None,
             # submission_end_user_id=_end_user.id,
         )
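
The web-facing variant keeps its per-IP rate limit in front of the same validation. The diff does not show RateLimiter's internals; below is a sliding-window sketch of the is_rate_limited contract it exposes (in-process state here, whereas the prefix= argument suggests the real implementation is backed by a shared store such as Redis):

    import time

    class SlidingWindowRateLimiter:
        def __init__(self, max_attempts: int, time_window_seconds: float) -> None:
            self.max_attempts = max_attempts
            self.time_window = time_window_seconds
            self._hits: dict[str, list[float]] = {}

        def is_rate_limited(self, key: str) -> bool:
            # Keep only hits inside the window, then count this attempt.
            now = time.monotonic()
            hits = [t for t in self._hits.get(key, []) if now - t < self.time_window]
            hits.append(now)
            self._hits[key] = hits
            return len(hits) > self.max_attempts

    limiter = SlidingWindowRateLimiter(max_attempts=3, time_window_seconds=60)
    assert not limiter.is_rate_limited("203.0.113.7")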

View File

@@ -2,7 +2,6 @@ import logging
 import time
 from typing import cast
 
-from graphon.entities import GraphInitParams
 from graphon.enums import WorkflowType
 from graphon.graph import Graph
 from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent
@@ -22,7 +21,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
 from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id
+from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory, get_default_root_node_id
 from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
 from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
 from core.workflow.workflow_entry import WorkflowEntry
@@ -265,22 +264,23 @@ class PipelineRunner(WorkflowBasedAppRunner):
         # graph_config["nodes"] = real_run_nodes
         # graph_config["edges"] = real_edges
 
         # init graph
-        # Create required parameters for Graph.init
-        graph_init_params = GraphInitParams(
+        # Create explicit graph init context for Graph.init.
+        run_context = build_dify_run_context(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            user_id=self.application_generate_entity.user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow.id,
             graph_config=graph_config,
-            run_context=build_dify_run_context(
-                tenant_id=workflow.tenant_id,
-                app_id=self._app_id,
-                user_id=self.application_generate_entity.user_id,
-                user_from=user_from,
-                invoke_from=invoke_from,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
 
         if start_node_id is None:

View File

@@ -3,7 +3,6 @@ import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
-from graphon.entities import GraphInitParams
 from graphon.entities.graph_config import NodeConfigDictAdapter
 from graphon.entities.pause_reason import HumanInputRequired
 from graphon.graph import Graph
@@ -67,7 +66,12 @@ from core.app.entities.queue_entities import (
     QueueWorkflowSucceededEvent,
 )
 from core.rag.entities import RetrievalSourceMetadata
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
+from core.workflow.node_factory import (
+    DifyGraphInitContext,
+    DifyNodeFactory,
+    get_default_root_node_id,
+    resolve_workflow_node_class,
+)
 from core.workflow.system_variables import (
     build_bootstrap_variables,
     default_system_variables,
@@ -127,24 +131,25 @@ class WorkflowBasedAppRunner:
         if not isinstance(graph_config.get("edges"), list):
             raise ValueError("edges in workflow graph must be a list")
 
-        # Create required parameters for Graph.init
-        graph_init_params = GraphInitParams(
+        # Create explicit graph init context for Graph.init.
+        run_context = build_dify_run_context(
+            tenant_id=tenant_id or "",
+            app_id=self._app_id,
+            user_id=user_id,
+            user_from=user_from,
+            invoke_from=invoke_from,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow_id,
             graph_config=graph_config,
-            run_context=build_dify_run_context(
-                tenant_id=tenant_id or "",
-                app_id=self._app_id,
-                user_id=user_id,
-                user_from=user_from,
-                invoke_from=invoke_from,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
 
         # Use the provided graph_runtime_state for consistent state management
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
@@ -289,22 +294,23 @@ class WorkflowBasedAppRunner:
         typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs]
 
-        # Create required parameters for Graph.init
-        graph_init_params = GraphInitParams(
+        # Create explicit graph init context for Graph.init.
+        run_context = build_dify_run_context(
+            tenant_id=workflow.tenant_id,
+            app_id=self._app_id,
+            user_id=user_id,
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.DEBUGGER,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow.id,
             graph_config=graph_config,
-            run_context=build_dify_run_context(
-                tenant_id=workflow.tenant_id,
-                app_id=self._app_id,
-                user_id=user_id,
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.DEBUGGER,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )

View File

@@ -146,7 +146,7 @@ def discover_protected_resource_metadata(
                 return ProtectedResourceMetadata.model_validate(response.json())
             elif response.status_code == 404:
                 continue  # Try next URL
-        except (RequestError, ValidationError):
+        except (RequestError, ValidationError, json.JSONDecodeError):
             continue  # Try next URL
 
     return None
@@ -166,7 +166,7 @@ def discover_oauth_authorization_server_metadata(
                 return OAuthMetadata.model_validate(response.json())
             elif response.status_code == 404:
                 continue  # Try next URL
-        except (RequestError, ValidationError):
+        except (RequestError, ValidationError, json.JSONDecodeError):
             continue  # Try next URL
 
     return None
@@ -276,7 +276,7 @@ def check_support_resource_discovery(server_url: str) -> tuple[bool, str]:
             else:
                 return False, ""
         return False, ""
-    except RequestError:
+    except (RequestError, json.JSONDecodeError, IndexError):
         # Not support resource discovery, fall back to well-known OAuth metadata
         return False, ""
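
The widened except clauses guard against servers that answer 200 with a non-JSON body: response.json() then raises json.JSONDecodeError, which previously escaped the discovery loop instead of falling through to the next candidate URL. A small illustration of that failure mode:

    import json

    def try_parse_metadata(body: str) -> dict | None:
        # A 200 response carrying an HTML error page used to crash discovery;
        # it is now treated the same as a 404: try the next URL.
        try:
            return json.loads(body)
        except json.JSONDecodeError:
            return None

    assert try_parse_metadata('{"issuer": "https://idp.example"}') is not None
    assert try_parse_metadata("<html>Bad Gateway</html>") is None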

View File

@@ -61,27 +61,28 @@ class TokenBufferMemory:
         :param is_user_message: whether this is a user message
         :return: PromptMessage
         """
-        if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
-            file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
-        elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            app = self.conversation.app
-            if not app:
-                raise ValueError("App not found for conversation")
-            if not message.workflow_run_id:
-                raise ValueError("Workflow run ID not found")
-            workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
-                tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
-            )
-            if not workflow_run:
-                raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
-            workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
-            if not workflow:
-                raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
-            file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
-        else:
-            raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
+        match self.conversation.mode:
+            case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
+                file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+            case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
+                app = self.conversation.app
+                if not app:
+                    raise ValueError("App not found for conversation")
+                if not message.workflow_run_id:
+                    raise ValueError("Workflow run ID not found")
+                workflow_run = self.workflow_run_repo.get_workflow_run_by_id(
+                    tenant_id=app.tenant_id, app_id=app.id, run_id=message.workflow_run_id
+                )
+                if not workflow_run:
+                    raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
+                workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
+                if not workflow:
+                    raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
+                file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+            case _:
+                raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
 
         detail = ImagePromptMessageContent.DETAIL.HIGH
         if file_extra_config and app_record:
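
The rewrite is behavior-preserving: case A | B is an or-pattern, and enum members in a value pattern match by equality, so the old if/elif chain carries over branch-for-branch while case _ makes the fallthrough explicit. A standalone illustration with invented names:

    from enum import Enum

    class Mode(Enum):
        CHAT = "chat"
        WORKFLOW = "workflow"
        OTHER = "other"

    def describe(mode: Mode) -> str:
        match mode:
            case Mode.CHAT | Mode.WORKFLOW:  # or-pattern: either member matches
                return "conversational"
            case _:
                return "something else"

    assert describe(Mode.WORKFLOW) == "conversational"
    assert describe(Mode.OTHER) == "something else"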

View File

@@ -5,6 +5,7 @@ from typing import Any
 
 from elasticsearch import Elasticsearch
 from pydantic import BaseModel, model_validator
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -19,6 +20,16 @@ from models.dataset import Dataset
 logger = logging.getLogger(__name__)
 
 
+class HuaweiElasticsearchParamsDict(TypedDict, total=False):
+    hosts: list[str]
+    verify_certs: bool
+    ssl_show_warn: bool
+    request_timeout: int
+    retry_on_timeout: bool
+    max_retries: int
+    basic_auth: tuple[str, str]
+
+
 def create_ssl_context() -> ssl.SSLContext:
     ssl_context = ssl.create_default_context()
     ssl_context.check_hostname = False
@@ -38,15 +49,15 @@ class HuaweiCloudVectorConfig(BaseModel):
             raise ValueError("config HOSTS is required")
         return values
 
-    def to_elasticsearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": self.hosts.split(","),
-            "verify_certs": False,
-            "ssl_show_warn": False,
-            "request_timeout": 30000,
-            "retry_on_timeout": True,
-            "max_retries": 10,
-        }
+    def to_elasticsearch_params(self) -> HuaweiElasticsearchParamsDict:
+        params = HuaweiElasticsearchParamsDict(
+            hosts=self.hosts.split(","),
+            verify_certs=False,
+            ssl_show_warn=False,
+            request_timeout=30000,
+            retry_on_timeout=True,
+            max_retries=10,
+        )
         if self.username and self.password:
             params["basic_auth"] = (self.username, self.password)
         return params
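
total=False is what lets basic_auth stay absent until credentials are configured: every key becomes optional, yet both the constructor call and later item assignments are still checked against the declared value types, and the object remains a plain dict at runtime. The same pattern in isolation (invented names):

    from typing_extensions import TypedDict

    class ConnParams(TypedDict, total=False):
        hosts: list[str]
        basic_auth: tuple[str, str]  # optional: only present when credentials exist

    def build_params(hosts: str, username: str | None, password: str | None) -> ConnParams:
        params = ConnParams(hosts=hosts.split(","))
        if username and password:
            params["basic_auth"] = (username, password)  # checked: tuple[str, str]
        return params

    assert "basic_auth" not in build_params("a:9200,b:9200", None, None)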

View File

@@ -7,6 +7,7 @@ from opensearchpy import OpenSearch, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
 from tenacity import retry, stop_after_attempt, wait_exponential
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from core.rag.datasource.vdb.field import Field
@@ -26,6 +27,14 @@ ROUTING_FIELD = "routing_field"
 UGC_INDEX_PREFIX = "ugc_index"
 
 
+class LindormOpenSearchParamsDict(TypedDict, total=False):
+    hosts: str | None
+    use_ssl: bool
+    pool_maxsize: int
+    timeout: int
+    http_auth: tuple[str, str]
+
+
 class LindormVectorStoreConfig(BaseModel):
     hosts: str | None
     username: str | None = None
@@ -44,13 +53,13 @@ class LindormVectorStoreConfig(BaseModel):
             raise ValueError("config PASSWORD is required")
         return values
 
-    def to_opensearch_params(self) -> dict[str, Any]:
-        params: dict[str, Any] = {
-            "hosts": self.hosts,
-            "use_ssl": False,
-            "pool_maxsize": 128,
-            "timeout": 30,
-        }
+    def to_opensearch_params(self) -> LindormOpenSearchParamsDict:
+        params = LindormOpenSearchParamsDict(
+            hosts=self.hosts,
+            use_ssl=False,
+            pool_maxsize=128,
+            timeout=30,
+        )
         if self.username and self.password:
             params["http_auth"] = (self.username, self.password)
         return params

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
 
 from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
 from opensearchpy.helpers import BulkIndexError
 from pydantic import BaseModel, model_validator
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from configs.middleware.vdb.opensearch_config import AuthMethod
@@ -21,6 +22,20 @@ from models.dataset import Dataset
 logger = logging.getLogger(__name__)
 
 
+class _OpenSearchHostDict(TypedDict):
+    host: str
+    port: int
+
+
+class OpenSearchParamsDict(TypedDict, total=False):
+    hosts: list[_OpenSearchHostDict]
+    use_ssl: bool
+    verify_certs: bool
+    connection_class: type
+    pool_maxsize: int
+    http_auth: tuple[str | None, str | None] | Urllib3AWSV4SignerAuth
+
+
 class OpenSearchConfig(BaseModel):
     host: str
     port: int
@@ -57,14 +72,14 @@ class OpenSearchConfig(BaseModel):
             service=self.aws_service,  # type: ignore[arg-type]
         )
 
-    def to_opensearch_params(self) -> dict[str, Any]:
-        params = {
-            "hosts": [{"host": self.host, "port": self.port}],
-            "use_ssl": self.secure,
-            "verify_certs": self.verify_certs,
-            "connection_class": Urllib3HttpConnection,
-            "pool_maxsize": 20,
-        }
+    def to_opensearch_params(self) -> OpenSearchParamsDict:
+        params = OpenSearchParamsDict(
+            hosts=[{"host": self.host, "port": self.port}],
+            use_ssl=self.secure,
+            verify_certs=self.verify_certs,
+            connection_class=Urllib3HttpConnection,
+            pool_maxsize=20,
+        )
 
         if self.auth_method == "basic":
             logger.info("Using basic authentication for OpenSearch Vector DB")

View File

@@ -19,5 +19,18 @@ def remove_leading_symbols(text: str) -> str:
     # Match Unicode ranges for punctuation and symbols
     # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
-    pattern = r'^[\[\]\u2000-\u2025\u2027-\u206F\u2E00-\u2E7F\u3000-\u300F\u3011-\u303F"#$%&\'()*+,./:;<=>?@^_`~]+'
+    pattern = re.compile(
+        r"""
+        ^
+        (?:
+            [\u2000-\u2025]             # General Punctuation: spaces, quotes, dashes
+          | [\u2027-\u206F]             # General Punctuation: ellipsis, underscores, etc.
+          | [\u2E00-\u2E7F]             # Supplemental Punctuation: medieval, ancient marks
+          | [\u3000-\u300F]             # CJK Punctuation: 、。〃「」『』 (excludes 【】)
+          | [\u3012-\u303F]             # CJK Punctuation: 〖〗〔〕〘〙〚〛〜 etc.
+          | ["#$%&'()*+,./:;<=>?@^_`~]  # ASCII punctuation (excludes [] and 【】)
+        )+
+        """,
+        re.VERBOSE,
+    )
     return re.sub(pattern, "", text)
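
re.VERBOSE is what makes the per-range comments legal: unescaped whitespace and #-comments outside character classes are ignored when the pattern compiles, so the verbose form matches the same text as an equivalent one-liner. (Note the rewrite also deliberately stops stripping the bracket characters, per its 'excludes' comments.) A reduced demonstration:

    import re

    pattern = re.compile(
        r"""
        ^
        (?:
            [\u3000-\u300F]               # CJK punctuation such as 、 and 。
          | ["#$%&'()*+,./:;<=>?@^_`~]    # ASCII punctuation, brackets excluded
        )+
        """,
        re.VERBOSE,  # whitespace/comments outside character classes are ignored
    )

    assert pattern.sub("", "。、hello") == "hello"
    assert pattern.sub("", "[tag] hello") == "[tag] hello"  # brackets survive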

View File

@@ -1,6 +1,7 @@
 import importlib
 import pkgutil
 from collections.abc import Callable, Iterator, Mapping, MutableMapping
+from dataclasses import dataclass
 from functools import lru_cache
 from typing import TYPE_CHECKING, Any, cast, final, override
@@ -67,6 +68,31 @@ _START_NODE_TYPES: frozenset[NodeType] = frozenset(
 )
 
 
+@dataclass(frozen=True, slots=True)
+class DifyGraphInitContext:
+    """Explicit graph-init values owned by the workflow layer.
+
+    Dify is gradually removing direct `GraphInitParams` construction from its
+    production call sites. Keep the translation here until `graphon` exposes an
+    equivalent explicit API.
+    """
+
+    workflow_id: str
+    graph_config: Mapping[str, Any]
+    run_context: Mapping[str, Any]
+    call_depth: int
+
+    def to_graph_init_params(self) -> "GraphInitParams":
+        from graphon.entities import GraphInitParams
+
+        return GraphInitParams(
+            workflow_id=self.workflow_id,
+            graph_config=self.graph_config,
+            run_context=self.run_context,
+            call_depth=self.call_depth,
+        )
+
+
 def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None:
     package = importlib.import_module(package_name)
     for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
@@ -237,6 +263,19 @@ class DifyNodeFactory(NodeFactory):
     Default implementation of NodeFactory that resolves node classes from the live registry.
     """
 
+    @classmethod
+    def from_graph_init_context(
+        cls,
+        *,
+        graph_init_context: DifyGraphInitContext,
+        graph_runtime_state: "GraphRuntimeState",
+    ) -> "DifyNodeFactory":
+        """Bridge Dify's explicit init context into the current `graphon` API."""
+        return cls(
+            graph_init_params=graph_init_context.to_graph_init_params(),
+            graph_runtime_state=graph_runtime_state,
+        )
+
     def __init__(
         self,
         graph_init_params: "GraphInitParams",
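
Taken together, the dataclass and the classmethod give every call site the same two-step recipe. A condensed sketch of the intended call shape, mirroring the refactored runners elsewhere in this commit (argument values are placeholders; graph_runtime_state is assumed to have been constructed earlier):

    from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
    from core.workflow.node_factory import DifyGraphInitContext, DifyNodeFactory

    run_context = build_dify_run_context(
        tenant_id="tenant-id",
        app_id="app-id",
        user_id="user-id",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
    )
    graph_init_context = DifyGraphInitContext(
        workflow_id="workflow-id",
        graph_config={"nodes": [], "edges": []},
        run_context=run_context,
        call_depth=0,
    )
    # from_graph_init_context() converts back to GraphInitParams internally,
    # so graphon's current constructor signature stays untouched.
    node_factory = DifyNodeFactory.from_graph_init_context(
        graph_init_context=graph_init_context,
        graph_runtime_state=graph_runtime_state,  # assumed in scope
    )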

View File

@@ -29,7 +29,7 @@ class TriggerWebhookNode(Node[WebhookData]):
     def post_init(self) -> None:
         from core.workflow.node_runtime import DifyFileReferenceFactory
 
-        self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
+        self._file_reference_factory = DifyFileReferenceFactory(self.run_context)
 
     @classmethod
     def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:

View File

@@ -24,7 +24,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_di
 from core.app.file_access import DatabaseFileAccessController
 from core.app.workflow.layers.llm_quota import LLMQuotaLayer
 from core.app.workflow.layers.observability import ObservabilityLayer
-from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class
+from core.workflow.node_factory import (
+    DifyGraphInitContext,
+    DifyNodeFactory,
+    is_start_node_type,
+    resolve_workflow_node_class,
+)
 from core.workflow.system_variables import (
     default_system_variables,
     get_node_creation_preload_selectors,
@@ -251,17 +256,18 @@ class WorkflowEntry:
         node_version = str(node_config_data.version)
         node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version)
 
-        # init graph init params and runtime state
-        graph_init_params = GraphInitParams(
+        # init graph context and runtime state
+        run_context = build_dify_run_context(
+            tenant_id=workflow.tenant_id,
+            app_id=workflow.app_id,
+            user_id=user_id,
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.DEBUGGER,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id=workflow.id,
             graph_config=workflow.graph_dict,
-            run_context=build_dify_run_context(
-                tenant_id=workflow.tenant_id,
-                app_id=workflow.app_id,
-                user_id=user_id,
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.DEBUGGER,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
         graph_runtime_state = GraphRuntimeState(
@@ -313,8 +319,8 @@ class WorkflowEntry:
         )
 
         # init workflow run state
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
         node = node_factory.create_node(node_config)
@@ -409,17 +415,18 @@ class WorkflowEntry:
         variable_pool = VariablePool()
         add_variables_to_pool(variable_pool, default_system_variables())
 
-        # init graph init params and runtime state
-        graph_init_params = GraphInitParams(
+        # init graph context and runtime state
+        run_context = build_dify_run_context(
+            tenant_id=tenant_id,
+            app_id="",
+            user_id=user_id,
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.DEBUGGER,
+        )
+        graph_init_context = DifyGraphInitContext(
             workflow_id="",
             graph_config=graph_dict,
-            run_context=build_dify_run_context(
-                tenant_id=tenant_id,
-                app_id="",
-                user_id=user_id,
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.DEBUGGER,
-            ),
+            run_context=run_context,
             call_depth=0,
         )
         graph_runtime_state = GraphRuntimeState(
@@ -430,8 +437,8 @@ class WorkflowEntry:
 
         # init workflow run state
         node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
-        node_factory = DifyNodeFactory(
-            graph_init_params=graph_init_params,
+        node_factory = DifyNodeFactory.from_graph_init_context(
+            graph_init_context=graph_init_context,
             graph_runtime_state=graph_runtime_state,
         )
         node = node_factory.create_node(node_config)

View File

@@ -5,12 +5,30 @@ from typing import Any
 import pytz  # type: ignore[import-untyped]
 from celery import Celery, Task
 from celery.schedules import crontab
+from typing_extensions import TypedDict
 
 from configs import dify_config
 from dify_app import DifyApp
 
 
-def get_celery_ssl_options() -> dict[str, Any] | None:
+class _CelerySentinelKwargsDict(TypedDict):
+    socket_timeout: float | None
+    password: str | None
+
+
+class CelerySentinelTransportDict(TypedDict):
+    master_name: str | None
+    sentinel_kwargs: _CelerySentinelKwargsDict
+
+
+class CelerySSLOptionsDict(TypedDict):
+    ssl_cert_reqs: int
+    ssl_ca_certs: str | None
+    ssl_certfile: str | None
+    ssl_keyfile: str | None
+
+
+def get_celery_ssl_options() -> CelerySSLOptionsDict | None:
     """Get SSL configuration for Celery broker/backend connections."""
     # Only apply SSL if we're using Redis as broker/backend
     if not dify_config.BROKER_USE_SSL:
@@ -33,26 +51,24 @@ def get_celery_ssl_options() -> dict[str, Any] | None:
 
     ssl_cert_reqs = cert_reqs_map.get(dify_config.REDIS_SSL_CERT_REQS, ssl.CERT_NONE)
 
-    ssl_options = {
-        "ssl_cert_reqs": ssl_cert_reqs,
-        "ssl_ca_certs": dify_config.REDIS_SSL_CA_CERTS,
-        "ssl_certfile": dify_config.REDIS_SSL_CERTFILE,
-        "ssl_keyfile": dify_config.REDIS_SSL_KEYFILE,
-    }
-
-    return ssl_options
+    return CelerySSLOptionsDict(
+        ssl_cert_reqs=ssl_cert_reqs,
+        ssl_ca_certs=dify_config.REDIS_SSL_CA_CERTS,
+        ssl_certfile=dify_config.REDIS_SSL_CERTFILE,
+        ssl_keyfile=dify_config.REDIS_SSL_KEYFILE,
+    )
 
 
-def get_celery_broker_transport_options() -> dict[str, Any]:
+def get_celery_broker_transport_options() -> CelerySentinelTransportDict | dict[str, Any]:
     """Get broker transport options (e.g. Redis Sentinel) for Celery connections."""
     if dify_config.CELERY_USE_SENTINEL:
-        return {
-            "master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
-            "sentinel_kwargs": {
-                "socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
-                "password": dify_config.CELERY_SENTINEL_PASSWORD,
-            },
-        }
+        return CelerySentinelTransportDict(
+            master_name=dify_config.CELERY_SENTINEL_MASTER_NAME,
+            sentinel_kwargs=_CelerySentinelKwargsDict(
+                socket_timeout=dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
+                password=dify_config.CELERY_SENTINEL_PASSWORD,
+            ),
+        )
     return {}
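
Because TypedDicts are plain dicts at runtime, the typed return values drop straight into Celery's standard settings; a sketch of the consuming side (using Celery's documented broker_use_ssl, redis_backend_use_ssl, and broker_transport_options keys):

    from celery import Celery

    celery_app = Celery("dify")

    ssl_options = get_celery_ssl_options()  # CelerySSLOptionsDict | None
    if ssl_options:
        # Celery accepts any mapping here; the TypedDict costs nothing at runtime.
        celery_app.conf.broker_use_ssl = ssl_options
        celery_app.conf.redis_backend_use_ssl = ssl_options
    celery_app.conf.broker_transport_options = get_celery_broker_transport_options()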

View File

@@ -674,28 +674,24 @@ class AppModelConfig(TypeBase):
     def suggested_questions_list(self) -> list[str]:
         return json.loads(self.suggested_questions) if self.suggested_questions else []
 
+    def _get_enabled_config(self, value: str | None, *, default_enabled: bool = False) -> EnabledConfig:
+        return cast(EnabledConfig, json.loads(value) if value else {"enabled": default_enabled})
+
     @property
     def suggested_questions_after_answer_dict(self) -> EnabledConfig:
-        return cast(
-            EnabledConfig,
-            json.loads(self.suggested_questions_after_answer)
-            if self.suggested_questions_after_answer
-            else {"enabled": False},
-        )
+        return self._get_enabled_config(self.suggested_questions_after_answer)
 
     @property
     def speech_to_text_dict(self) -> EnabledConfig:
-        return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False})
+        return self._get_enabled_config(self.speech_to_text)
 
     @property
     def text_to_speech_dict(self) -> EnabledConfig:
-        return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False})
+        return self._get_enabled_config(self.text_to_speech)
 
     @property
     def retriever_resource_dict(self) -> EnabledConfig:
-        return cast(
-            EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True}
-        )
+        return self._get_enabled_config(self.retriever_resource, default_enabled=True)
 
     @property
     def annotation_reply_dict(self) -> AnnotationReplyConfig:
@@ -722,7 +718,7 @@ class AppModelConfig(TypeBase):
 
     @property
     def more_like_this_dict(self) -> EnabledConfig:
-        return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False})
+        return self._get_enabled_config(self.more_like_this)
 
     @property
     def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig:
@@ -902,7 +898,7 @@ class InstalledApp(TypeBase):
         return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id))
 
 
-class TrialApp(Base):
+class TrialApp(TypeBase):
     __tablename__ = "trial_apps"
     __table_args__ = (
         sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
@@ -911,18 +907,26 @@ class TrialApp(Base):
         sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
     )
 
-    id = mapped_column(StringUUID, default=gen_uuidv4_string)
-    app_id = mapped_column(StringUUID, nullable=False)
-    tenant_id = mapped_column(StringUUID, nullable=False)
-    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
-    trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
+    )
+    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(
+        sa.DateTime,
+        nullable=False,
+        insert_default=func.current_timestamp(),
+        server_default=func.current_timestamp(),
+        init=False,
+    )
+    trial_limit: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=3)
 
     @property
     def app(self) -> App | None:
         return db.session.scalar(select(App).where(App.id == self.app_id))
 
 
-class AccountTrialAppRecord(Base):
+class AccountTrialAppRecord(TypeBase):
     __tablename__ = "account_trial_app_records"
     __table_args__ = (
         sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
@@ -930,11 +934,19 @@ class AccountTrialAppRecord(Base):
         sa.Index("account_trial_app_record_app_id_idx", "app_id"),
         sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
     )
 
-    id = mapped_column(StringUUID, default=gen_uuidv4_string)
-    account_id = mapped_column(StringUUID, nullable=False)
-    app_id = mapped_column(StringUUID, nullable=False)
-    count = mapped_column(sa.Integer, nullable=False, default=0)
-    created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
+    id: Mapped[str] = mapped_column(
+        StringUUID, insert_default=gen_uuidv4_string, default_factory=gen_uuidv4_string, init=False
+    )
+    account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0)
+    created_at: Mapped[datetime] = mapped_column(
+        sa.DateTime,
+        nullable=False,
+        insert_default=func.current_timestamp(),
+        server_default=func.current_timestamp(),
+        init=False,
+    )
 
     @property
     def app(self) -> App | None:
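
The retyped columns follow SQLAlchemy 2.0 dataclass-mapping conventions, assuming TypeBase mixes in MappedAsDataclass (implied but not shown in this diff): init=False keeps generated values out of the __init__ signature, default_factory fills them in Python, and insert_default/server_default cover the SQL side. A self-contained sketch of the same column shapes:

    import uuid
    from datetime import datetime

    from sqlalchemy import DateTime, String, func
    from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, mapped_column

    class Base(MappedAsDataclass, DeclarativeBase):  # assumed shape of TypeBase
        pass

    class TrialAppSketch(Base):
        __tablename__ = "trial_apps_sketch"

        # init=False: callers never pass id; the factory fills it on construction.
        id: Mapped[str] = mapped_column(
            String(36),
            primary_key=True,
            insert_default=lambda: str(uuid.uuid4()),
            default_factory=lambda: str(uuid.uuid4()),
            init=False,
        )
        app_id: Mapped[str] = mapped_column(String(36))
        # server_default also covers rows inserted outside the ORM.
        created_at: Mapped[datetime] = mapped_column(
            DateTime,
            insert_default=func.current_timestamp(),
            server_default=func.current_timestamp(),
            init=False,
        )
        trial_limit: Mapped[int] = mapped_column(default=3)

    row = TrialAppSketch(app_id="abc")  # id/created_at are not __init__ arguments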

View File

@@ -66,12 +66,15 @@ def build_file_from_stored_mapping(
     record_id = resolve_file_record_id(mapping)
     transfer_method = FileTransferMethod.value_of(mapping["transfer_method"])
-    if transfer_method == FileTransferMethod.TOOL_FILE and record_id:
-        mapping["tool_file_id"] = record_id
-    elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id:
-        mapping["upload_file_id"] = record_id
-    elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id:
-        mapping["datasource_file_id"] = record_id
+    match transfer_method:
+        case FileTransferMethod.TOOL_FILE if record_id:
+            mapping["tool_file_id"] = record_id
+        case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL if record_id:
+            mapping["upload_file_id"] = record_id
+        case FileTransferMethod.DATASOURCE_FILE if record_id:
+            mapping["datasource_file_id"] = record_id
+        case _:
+            pass
 
     if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None:
         remote_url = mapping.get("remote_url")
View File

@@ -173,7 +173,7 @@ dev = [
     "sseclient-py>=1.8.0",
     "pytest-timeout>=2.4.0",
     "pytest-xdist>=3.8.0",
-    "pyrefly>=0.59.1",
+    "pyrefly>=0.60.0",
 ]
############################################################ ############################################################

View File

@@ -467,61 +467,67 @@ class AppDslService:
         )
 
         # Initialize app based on mode
-        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            workflow_data = data.get("workflow")
-            if not workflow_data or not isinstance(workflow_data, dict):
-                raise ValueError("Missing workflow data for workflow/advanced chat app")
-
-            environment_variables_list = workflow_data.get("environment_variables", [])
-            environment_variables = [
-                variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
-            ]
-            conversation_variables_list = workflow_data.get("conversation_variables", [])
-            conversation_variables = [
-                variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
-            ]
-
-            workflow_service = WorkflowService()
-            current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
-            if current_draft_workflow:
-                unique_hash = current_draft_workflow.unique_hash
-            else:
-                unique_hash = None
-            graph = workflow_data.get("graph", {})
-            for node in graph.get("nodes", []):
-                if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
-                    dataset_ids = node["data"].get("dataset_ids", [])
-                    node["data"]["dataset_ids"] = [
-                        decrypted_id
-                        for dataset_id in dataset_ids
-                        if (decrypted_id := self.decrypt_dataset_id(encrypted_data=dataset_id, tenant_id=app.tenant_id))
-                    ]
-            workflow_service.sync_draft_workflow(
-                app_model=app,
-                graph=workflow_data.get("graph", {}),
-                features=workflow_data.get("features", {}),
-                unique_hash=unique_hash,
-                account=account,
-                environment_variables=environment_variables,
-                conversation_variables=conversation_variables,
-            )
-        elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}:
-            # Initialize model config
-            model_config = data.get("model_config")
-            if not model_config or not isinstance(model_config, dict):
-                raise ValueError("Missing model_config for chat/agent-chat/completion app")
-            # Initialize or update model config
-            if not app.app_model_config:
-                app_model_config = AppModelConfig(
-                    app_id=app.id, created_by=account.id, updated_by=account.id
-                ).from_model_config_dict(cast(AppModelConfigDict, model_config))
-                app_model_config.id = str(uuid4())
-                app.app_model_config_id = app_model_config.id
-
-                self._session.add(app_model_config)
-                app_model_config_was_updated.send(app, app_model_config=app_model_config)
-        else:
-            raise ValueError("Invalid app mode")
+        match app_mode:
+            case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
+                workflow_data = data.get("workflow")
+                if not workflow_data or not isinstance(workflow_data, dict):
+                    raise ValueError("Missing workflow data for workflow/advanced chat app")
+
+                environment_variables_list = workflow_data.get("environment_variables", [])
+                environment_variables = [
+                    variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
+                ]
+                conversation_variables_list = workflow_data.get("conversation_variables", [])
+                conversation_variables = [
+                    variable_factory.build_conversation_variable_from_mapping(obj)
+                    for obj in conversation_variables_list
+                ]
+
+                workflow_service = WorkflowService()
+                current_draft_workflow = workflow_service.get_draft_workflow(app_model=app)
+                if current_draft_workflow:
+                    unique_hash = current_draft_workflow.unique_hash
+                else:
+                    unique_hash = None
+                graph = workflow_data.get("graph", {})
+                for node in graph.get("nodes", []):
+                    if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL:
+                        dataset_ids = node["data"].get("dataset_ids", [])
+                        node["data"]["dataset_ids"] = [
+                            decrypted_id
+                            for dataset_id in dataset_ids
+                            if (
+                                decrypted_id := self.decrypt_dataset_id(
+                                    encrypted_data=dataset_id, tenant_id=app.tenant_id
+                                )
+                            )
+                        ]
+                workflow_service.sync_draft_workflow(
+                    app_model=app,
+                    graph=workflow_data.get("graph", {}),
+                    features=workflow_data.get("features", {}),
+                    unique_hash=unique_hash,
+                    account=account,
+                    environment_variables=environment_variables,
+                    conversation_variables=conversation_variables,
+                )
+            case AppMode.CHAT | AppMode.AGENT_CHAT | AppMode.COMPLETION:
+                # Initialize model config
+                model_config = data.get("model_config")
+                if not model_config or not isinstance(model_config, dict):
+                    raise ValueError("Missing model_config for chat/agent-chat/completion app")
+                # Initialize or update model config
+                if not app.app_model_config:
+                    app_model_config = AppModelConfig(
+                        app_id=app.id, created_by=account.id, updated_by=account.id
+                    ).from_model_config_dict(cast(AppModelConfigDict, model_config))
+                    app_model_config.id = str(uuid4())
+                    app.app_model_config_id = app_model_config.id
+
+                    self._session.add(app_model_config)
+                    app_model_config_was_updated.send(app, app_model_config=app_model_config)
+            case _:
+                raise ValueError("Invalid app mode")
 
         return app
 
     @classmethod

View File

@@ -132,8 +132,8 @@ class FileService:
         return file_size <= file_size_limit
 
     def get_file_base64(self, file_id: str) -> str:
-        upload_file = (
-            self._session_maker(expire_on_commit=False).query(UploadFile).where(UploadFile.id == file_id).first()
+        upload_file = self._session_maker(expire_on_commit=False).scalar(
+            select(UploadFile).where(UploadFile.id == file_id).limit(1)
         )
         if not upload_file:
             raise NotFound("File not found")
@@ -178,7 +178,7 @@ class FileService:
         Return a short text preview extracted from a document file.
         """
         with self._session_maker(expire_on_commit=False) as session:
-            upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
         if not upload_file:
             raise NotFound("File not found")
@@ -200,7 +200,7 @@ class FileService:
         if not result:
             raise NotFound("File not found or signature is invalid")
         with self._session_maker(expire_on_commit=False) as session:
-            upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
         if not upload_file:
             raise NotFound("File not found or signature is invalid")
@@ -220,7 +220,7 @@ class FileService:
             raise NotFound("File not found or signature is invalid")
 
         with self._session_maker(expire_on_commit=False) as session:
-            upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
         if not upload_file:
             raise NotFound("File not found or signature is invalid")
@@ -231,7 +231,7 @@ class FileService:
     def get_public_image_preview(self, file_id: str):
         with self._session_maker(expire_on_commit=False) as session:
-            upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            upload_file = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
         if not upload_file:
             raise NotFound("File not found or signature is invalid")
@@ -247,7 +247,7 @@ class FileService:
     def get_file_content(self, file_id: str) -> str:
         with self._session_maker(expire_on_commit=False) as session:
-            upload_file: UploadFile | None = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            upload_file: UploadFile | None = session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))
         if not upload_file:
             raise NotFound("File not found")
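
Each hunk above is the same mechanical translation from the legacy Query API to 2.0-style execution: session.query(X).where(...).first() becomes session.scalar(select(X).where(...).limit(1)), where limit(1) bounds the rows fetched and scalar() returns the first result or None. The shape in isolation (the UploadFile import path is assumed):

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from models.model import UploadFile  # assumed import path

    def get_upload_file(session: Session, file_id: str) -> UploadFile | None:
        # Legacy: session.query(UploadFile).where(UploadFile.id == file_id).first()
        return session.scalar(select(UploadFile).where(UploadFile.id == file_id).limit(1))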

View File

@@ -1,14 +1,13 @@
 from sqlalchemy import select
-from sqlalchemy.orm import sessionmaker
 
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
 from models.account import TenantPluginAutoUpgradeStrategy
 
 
 class PluginAutoUpgradeService:
     @staticmethod
     def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
-        with sessionmaker(bind=db.engine).begin() as session:
+        with session_factory.create_session() as session:
             return session.scalar(
                 select(TenantPluginAutoUpgradeStrategy)
                 .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@@ -24,7 +23,7 @@ class PluginAutoUpgradeService:
         exclude_plugins: list[str],
         include_plugins: list[str],
     ) -> bool:
-        with sessionmaker(bind=db.engine).begin() as session:
+        with session_factory.create_session() as session:
             exist_strategy = session.scalar(
                 select(TenantPluginAutoUpgradeStrategy)
                 .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
@@ -51,7 +50,7 @@ class PluginAutoUpgradeService:
 
     @staticmethod
     def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
-        with sessionmaker(bind=db.engine).begin() as session:
+        with session_factory.create_session() as session:
             exist_strategy = session.scalar(
                 select(TenantPluginAutoUpgradeStrategy)
                 .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)

View File

@@ -1,3 +1,5 @@
+from typing import Any, TypedDict
+
 from sqlalchemy import select
 
 from constants.languages import languages
@@ -8,16 +10,43 @@ from services.recommend_app.recommend_app_base import RecommendAppRetrievalBase
 from services.recommend_app.recommend_app_type import RecommendAppType
 
 
+class RecommendedAppItemDict(TypedDict):
+    id: str
+    app: App | None
+    app_id: str
+    description: Any
+    copyright: Any
+    privacy_policy: Any
+    custom_disclaimer: str
+    category: str
+    position: int
+    is_listed: bool
+
+
+class RecommendedAppsResultDict(TypedDict):
+    recommended_apps: list[RecommendedAppItemDict]
+    categories: list[str]
+
+
+class RecommendedAppDetailDict(TypedDict):
+    id: str
+    name: str
+    icon: Any
+    icon_background: str | None
+    mode: str
+    export_data: str
+
+
 class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
     """
     Retrieval recommended app from database
     """
 
-    def get_recommended_apps_and_categories(self, language: str):
+    def get_recommended_apps_and_categories(self, language: str) -> RecommendedAppsResultDict:
         result = self.fetch_recommended_apps_from_db(language)
         return result
 
-    def get_recommend_app_detail(self, app_id: str):
+    def get_recommend_app_detail(self, app_id: str) -> RecommendedAppDetailDict | None:
         result = self.fetch_recommended_app_detail_from_db(app_id)
         return result
@@ -25,7 +54,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
         return RecommendAppType.DATABASE
 
     @classmethod
-    def fetch_recommended_apps_from_db(cls, language: str):
+    def fetch_recommended_apps_from_db(cls, language: str) -> RecommendedAppsResultDict:
         """
         Fetch recommended apps from db.
         :param language: language
@@ -41,7 +70,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
         ).all()
 
         categories = set()
-        recommended_apps_result = []
+        recommended_apps_result: list[RecommendedAppItemDict] = []
         for recommended_app in recommended_apps:
             app = recommended_app.app
             if not app or not app.is_public:
@@ -51,7 +80,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
             if not site:
                 continue
 
-            recommended_app_result = {
+            recommended_app_result: RecommendedAppItemDict = {
                 "id": recommended_app.id,
                 "app": recommended_app.app,
                 "app_id": recommended_app.app_id,
@@ -67,10 +96,10 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
             categories.add(recommended_app.category)
 
-        return {"recommended_apps": recommended_apps_result, "categories": sorted(categories)}
+        return RecommendedAppsResultDict(recommended_apps=recommended_apps_result, categories=sorted(categories))
 
     @classmethod
-    def fetch_recommended_app_detail_from_db(cls, app_id: str) -> dict | None:
+    def fetch_recommended_app_detail_from_db(cls, app_id: str) -> RecommendedAppDetailDict | None:
         """
         Fetch recommended app detail from db.
         :param app_id: App ID
@@ -89,11 +118,11 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
         if not app_model or not app_model.is_public:
             return None
 
-        return {
-            "id": app_model.id,
-            "name": app_model.name,
-            "icon": app_model.icon,
-            "icon_background": app_model.icon_background,
-            "mode": app_model.mode,
-            "export_data": AppDslService.export_dsl(app_model=app_model),
-        }
+        return RecommendedAppDetailDict(
+            id=app_model.id,
+            name=app_model.name,
+            icon=app_model.icon,
+            icon_background=app_model.icon_background,
+            mode=app_model.mode,
+            export_data=AppDslService.export_dsl(app_model=app_model),
+        )

View File

@ -104,32 +104,32 @@ class WebhookService:
""" """
with Session(db.engine) as session: with Session(db.engine) as session:
# Get webhook trigger # Get webhook trigger
webhook_trigger = ( webhook_trigger = session.scalar(
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).first() select(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.webhook_id == webhook_id).limit(1)
) )
if not webhook_trigger: if not webhook_trigger:
raise ValueError(f"Webhook not found: {webhook_id}") raise ValueError(f"Webhook not found: {webhook_id}")
if is_debug: if is_debug:
workflow = ( workflow = session.scalar(
session.query(Workflow) select(Workflow)
.filter( .where(
Workflow.app_id == webhook_trigger.app_id, Workflow.app_id == webhook_trigger.app_id,
Workflow.version == Workflow.VERSION_DRAFT, Workflow.version == Workflow.VERSION_DRAFT,
) )
.order_by(Workflow.created_at.desc()) .order_by(Workflow.created_at.desc())
.first() .limit(1)
) )
else: else:
# Check if the corresponding AppTrigger exists # Check if the corresponding AppTrigger exists
app_trigger = ( app_trigger = session.scalar(
session.query(AppTrigger) select(AppTrigger)
.filter( .where(
AppTrigger.app_id == webhook_trigger.app_id, AppTrigger.app_id == webhook_trigger.app_id,
AppTrigger.node_id == webhook_trigger.node_id, AppTrigger.node_id == webhook_trigger.node_id,
AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK, AppTrigger.trigger_type == AppTriggerType.TRIGGER_WEBHOOK,
) )
.first() .limit(1)
) )
if not app_trigger: if not app_trigger:
@ -146,14 +146,14 @@ class WebhookService:
raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}") raise ValueError(f"Webhook trigger is disabled for webhook {webhook_id}")
# Get workflow # Get workflow
workflow = ( workflow = session.scalar(
session.query(Workflow) select(Workflow)
.filter( .where(
Workflow.app_id == webhook_trigger.app_id, Workflow.app_id == webhook_trigger.app_id,
Workflow.version != Workflow.VERSION_DRAFT, Workflow.version != Workflow.VERSION_DRAFT,
) )
.order_by(Workflow.created_at.desc()) .order_by(Workflow.created_at.desc())
.first() .limit(1)
) )
if not workflow: if not workflow:
raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}") raise ValueError(f"Workflow not found for app {webhook_trigger.app_id}")
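
The webhook lookups migrate from the legacy Query API to 2.0-style select() executed through session.scalar(). A self-contained sketch of the equivalence, using a stand-in model rather than the project's WorkflowWebhookTrigger:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Webhook(Base):  # stand-in model for illustration
    __tablename__ = "webhooks"
    id: Mapped[int] = mapped_column(primary_key=True)
    webhook_id: Mapped[str] = mapped_column(String(36))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Webhook(webhook_id="abc"))
    session.commit()
    # Legacy: session.query(Webhook).where(...).first()
    # 2.0:    scalar() returns the first column of the first row, or None;
    #         limit(1) keeps the emitted SQL equivalent to .first().
    trigger = session.scalar(
        select(Webhook).where(Webhook.webhook_id == "abc").limit(1)
    )
    assert trigger is not None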


@ -5,7 +5,7 @@ import uuid
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
from typing import Any, cast from typing import Any, cast
from graphon.entities import GraphInitParams, WorkflowNodeExecution from graphon.entities import WorkflowNodeExecution
from graphon.entities.graph_config import NodeConfigDict from graphon.entities.graph_config import NodeConfigDict
from graphon.entities.pause_reason import HumanInputRequired from graphon.entities.pause_reason import HumanInputRequired
from graphon.enums import ( from graphon.enums import (
@ -48,7 +48,12 @@ from core.workflow.human_input_compat import (
normalize_human_input_node_data_for_graph, normalize_human_input_node_data_for_graph,
parse_human_input_delivery_methods, parse_human_input_delivery_methods,
) )
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type from core.workflow.node_factory import (
LATEST_VERSION,
DifyGraphInitContext,
get_node_type_classes_mapping,
is_start_node_type,
)
from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
@ -1204,18 +1209,20 @@ class WorkflowService:
node_config: NodeConfigDict, node_config: NodeConfigDict,
variable_pool: VariablePool, variable_pool: VariablePool,
) -> HumanInputNode: ) -> HumanInputNode:
graph_init_params = GraphInitParams( run_context = build_dify_run_context(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=account.id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
)
graph_init_context = DifyGraphInitContext(
workflow_id=workflow.id, workflow_id=workflow.id,
graph_config=workflow.graph_dict, graph_config=workflow.graph_dict,
run_context=build_dify_run_context( run_context=run_context,
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
user_id=account.id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
call_depth=0, call_depth=0,
) )
graph_init_params = graph_init_context.to_graph_init_params()
graph_runtime_state = GraphRuntimeState( graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool, variable_pool=variable_pool,
start_at=time.perf_counter(), start_at=time.perf_counter(),
@ -1225,7 +1232,7 @@ class WorkflowService:
config=node_config, config=node_config,
graph_init_params=graph_init_params, graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state, graph_runtime_state=graph_runtime_state,
runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), runtime=DifyHumanInputNodeRuntime(run_context),
) )
return node return node
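
Note that build_dify_run_context is now computed once and shared between DifyGraphInitContext and the node runtime, with to_graph_init_params() deriving the legacy params object. A rough sketch of that translation pattern, using toy dataclasses in place of the real classes:

from dataclasses import dataclass
from typing import Any

@dataclass
class GraphInitParams:  # toy stand-in for graphon's class
    workflow_id: str
    graph_config: dict[str, Any]
    run_context: dict[str, Any]
    call_depth: int

@dataclass
class GraphInitContext:  # toy stand-in for DifyGraphInitContext
    workflow_id: str
    graph_config: dict[str, Any]
    run_context: dict[str, Any]
    call_depth: int = 0

    def to_graph_init_params(self) -> GraphInitParams:
        # Single translation point: callers build the context once and
        # derive the params, instead of assembling both by hand.
        return GraphInitParams(
            workflow_id=self.workflow_id,
            graph_config=self.graph_config,
            run_context=self.run_context,
            call_depth=self.call_depth,
        )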


@ -3,6 +3,7 @@ import time
import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from sqlalchemy import select, update
from core.db.session_factory import session_factory from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.doc_type import DocType
@ -26,43 +27,42 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
with session_factory.create_session() as session: with session_factory.create_session() as session:
try: try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first() dataset = session.scalar(select(Dataset).where(Dataset.id == dataset_id).limit(1))
if not dataset: if not dataset:
raise Exception("Dataset not found") raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade": if action == "upgrade":
dataset_documents = ( dataset_documents = session.scalars(
session.query(DatasetDocument) select(DatasetDocument).where(
.where(
DatasetDocument.dataset_id == dataset_id, DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
) )
.all() ).all()
)
if dataset_documents: if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents] dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( session.execute(
{"indexing_status": "indexing"}, synchronize_session=False update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
) )
session.commit() session.commit()
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
try: try:
# add from vector index # add from vector index
segments = ( segments = session.scalars(
session.query(DocumentSegment) select(DocumentSegment)
.where( .where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
) )
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
.all() ).all()
)
if segments: if segments:
documents = [] documents = []
for segment in segments: for segment in segments:
@ -81,32 +81,36 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
# clean keywords # clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False) index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False) index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( session.execute(
{"indexing_status": "completed"}, synchronize_session=False update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
) )
session.commit() session.commit()
except Exception as e: except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( session.execute(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
) )
session.commit() session.commit()
elif action == "update": elif action == "update":
dataset_documents = ( dataset_documents = session.scalars(
session.query(DatasetDocument) select(DatasetDocument).where(
.where(
DatasetDocument.dataset_id == dataset_id, DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed", DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
) )
.all() ).all()
)
# add new index # add new index
if dataset_documents: if dataset_documents:
# update document status # update document status
dataset_documents_ids = [doc.id for doc in dataset_documents] dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update( session.execute(
{"indexing_status": "indexing"}, synchronize_session=False update(DatasetDocument)
.where(DatasetDocument.id.in_(dataset_documents_ids))
.values(indexing_status="indexing")
) )
session.commit() session.commit()
@ -116,15 +120,14 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
# update from vector index # update from vector index
try: try:
segments = ( segments = session.scalars(
session.query(DocumentSegment) select(DocumentSegment)
.where( .where(
DocumentSegment.document_id == dataset_document.id, DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
) )
.order_by(DocumentSegment.position.asc()) .order_by(DocumentSegment.position.asc())
.all() ).all()
)
if segments: if segments:
documents = [] documents = []
multimodal_documents = [] multimodal_documents = []
@ -173,13 +176,17 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
index_processor.load( index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
) )
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( session.execute(
{"indexing_status": "completed"}, synchronize_session=False update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="completed")
) )
session.commit() session.commit()
except Exception as e: except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update( session.execute(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False update(DatasetDocument)
.where(DatasetDocument.id == dataset_document.id)
.values(indexing_status="error", error=str(e))
) )
session.commit() session.commit()
else: else:
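
The task replaces Query.update() with an explicit update() statement run through session.execute(). The bulk-update idiom in isolation, again with a stand-in model:

from sqlalchemy import String, create_engine, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Doc(Base):  # stand-in for DatasetDocument
    __tablename__ = "docs"
    id: Mapped[int] = mapped_column(primary_key=True)
    indexing_status: Mapped[str] = mapped_column(String(20), default="completed")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(Doc() for _ in range(3))
    session.commit()
    ids = session.scalars(select(Doc.id)).all()
    # Core-style bulk UPDATE: no synchronize_session argument to juggle,
    # and the statement composes like any other 2.0 construct.
    session.execute(
        update(Doc).where(Doc.id.in_(ids)).values(indexing_status="indexing")
    )
    session.commit()
    assert set(session.scalars(select(Doc.indexing_status))) == {"indexing"}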


@ -862,6 +862,15 @@ class TestAuthOrchestration:
result = discover_protected_resource_metadata(None, "https://api.example.com") result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None assert result is None
# JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_protected_resource_metadata(None, "https://api.example.com")
assert result is None
@patch("core.helper.ssrf_proxy.get") @patch("core.helper.ssrf_proxy.get")
def test_discover_oauth_authorization_server_metadata(self, mock_get): def test_discover_oauth_authorization_server_metadata(self, mock_get):
# Success # Success
@ -892,6 +901,14 @@ class TestAuthOrchestration:
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None assert result is None
# JSONDecodeError (non-JSON 200 response)
bad_json_response = Mock()
bad_json_response.status_code = 200
bad_json_response.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_response
result = discover_oauth_authorization_server_metadata(None, "https://api.example.com")
assert result is None
def test_get_effective_scope(self): def test_get_effective_scope(self):
prm = ProtectedResourceMetadata( prm = ProtectedResourceMetadata(
resource="https://api.example.com", resource="https://api.example.com",
@ -997,6 +1014,24 @@ class TestAuthOrchestration:
supported, url = check_support_resource_discovery("https://api") supported, url = check_support_resource_discovery("https://api")
assert supported is False assert supported is False
# Case 6: JSONDecodeError (non-JSON 200 response)
mock_get.side_effect = None
bad_json_res = Mock()
bad_json_res.status_code = 200
bad_json_res.json.side_effect = json.JSONDecodeError("Expecting value", "", 0)
mock_get.return_value = bad_json_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
assert url == ""
# Case 7: Empty authorization_servers array (IndexError)
empty_res = Mock()
empty_res.status_code = 200
empty_res.json.return_value = {"authorization_servers": []}
mock_get.return_value = empty_res
supported, url = check_support_resource_discovery("https://api")
assert supported is False
def test_discover_oauth_metadata(self): def test_discover_oauth_metadata(self):
with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm:
with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm:


@ -110,6 +110,34 @@ class TestFetchMemory:
) )
class TestDifyGraphInitContext:
def test_to_graph_init_params_preserves_explicit_values(self):
run_context = {
DIFY_RUN_CONTEXT_KEY: DifyRunContext(
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
"extra": "value",
}
graph_config = {"nodes": [], "edges": []}
graph_init_context = node_factory.DifyGraphInitContext(
workflow_id="workflow-id",
graph_config=graph_config,
run_context=run_context,
call_depth=2,
)
result = graph_init_context.to_graph_init_params()
assert result.workflow_id == "workflow-id"
assert result.graph_config == graph_config
assert result.run_context == run_context
assert result.call_depth == 2
class TestDefaultWorkflowCodeExecutor: class TestDefaultWorkflowCodeExecutor:
def test_execute_delegates_to_code_executor(self, monkeypatch): def test_execute_delegates_to_code_executor(self, monkeypatch):
executor = node_factory.DefaultWorkflowCodeExecutor() executor = node_factory.DefaultWorkflowCodeExecutor()
@ -172,6 +200,23 @@ class TestCodeExecutorJinja2TemplateRenderer:
class TestDifyNodeFactoryInit: class TestDifyNodeFactoryInit:
def test_from_graph_init_context_translates_before_init(self):
graph_init_context = MagicMock()
graph_init_context.to_graph_init_params.return_value = sentinel.graph_init_params
with patch.object(node_factory.DifyNodeFactory, "__init__", return_value=None) as init:
factory = node_factory.DifyNodeFactory.from_graph_init_context(
graph_init_context=graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state,
)
assert isinstance(factory, node_factory.DifyNodeFactory)
graph_init_context.to_graph_init_params.assert_called_once_with()
init.assert_called_once_with(
graph_init_params=sentinel.graph_init_params,
graph_runtime_state=sentinel.graph_runtime_state,
)
def test_init_builds_default_dependencies(self): def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"}) graph_init_params = SimpleNamespace(run_context={"context": "value"})
graph_runtime_state = sentinel.graph_runtime_state graph_runtime_state = sentinel.graph_runtime_state


@ -349,7 +349,7 @@ class TestWorkflowEntrySingleStepRun:
] ]
with ( with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object( patch.object(
workflow_entry, workflow_entry,
"GraphRuntimeState", "GraphRuntimeState",
@ -358,7 +358,7 @@ class TestWorkflowEntrySingleStepRun:
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode), patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeLLMNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "load_into_variable_pool"), patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"), patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
patch.object( patch.object(
@ -412,12 +412,12 @@ class TestWorkflowEntrySingleStepRun:
raise NotImplementedError raise NotImplementedError
with ( with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode), patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool, patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object( patch.object(
@ -481,12 +481,12 @@ class TestWorkflowEntrySingleStepRun:
return {"question": ["node", "question"]} return {"question": ["node", "question"]}
with ( with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode), patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeDatasourceNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool, patch.object(workflow_entry, "add_node_inputs_to_pool") as add_node_inputs_to_pool,
patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool,
patch.object( patch.object(
@ -541,12 +541,12 @@ class TestWorkflowEntrySingleStepRun:
return "1" return "1"
with ( with (
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode), patch.object(workflow_entry, "resolve_workflow_node_class", return_value=FakeNode),
patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, patch.object(workflow_entry.DifyNodeFactory, "from_graph_init_context") as dify_node_factory,
patch.object(workflow_entry, "add_node_inputs_to_pool"), patch.object(workflow_entry, "add_node_inputs_to_pool"),
patch.object(workflow_entry, "load_into_variable_pool"), patch.object(workflow_entry, "load_into_variable_pool"),
patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"), patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"),
@ -651,14 +651,18 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls, patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls,
patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool, patch.object(workflow_entry, "add_variables_to_pool") as add_variables_to_pool,
patch.object( patch.object(
workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context
) as graph_init_params, ) as graph_init_context_cls,
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object( patch.object(
workflow_entry, "build_dify_run_context", return_value={"_dify": "context"} workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}
) as build_dify_run_context, ) as build_dify_run_context,
patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls, patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
) as dify_node_factory_cls,
patch.object( patch.object(
workflow_entry.WorkflowEntry, workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool", "mapping_user_inputs_to_variable_pool",
@ -688,7 +692,7 @@ class TestWorkflowEntryHelpers:
user_from=UserFrom.ACCOUNT, user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER, invoke_from=InvokeFrom.DEBUGGER,
) )
graph_init_params.assert_called_once_with( graph_init_context_cls.assert_called_once_with(
workflow_id="", workflow_id="",
graph_config=workflow_entry.WorkflowEntry._create_single_node_graph( graph_config=workflow_entry.WorkflowEntry._create_single_node_graph(
"node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"} "node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}
@ -697,7 +701,7 @@ class TestWorkflowEntryHelpers:
call_depth=0, call_depth=0,
) )
dify_node_factory_cls.assert_called_once_with( dify_node_factory_cls.assert_called_once_with(
graph_init_params=sentinel.graph_init_params, graph_init_context=sentinel.graph_init_context,
graph_runtime_state=sentinel.graph_runtime_state, graph_runtime_state=sentinel.graph_runtime_state,
) )
mapping_user_inputs_to_variable_pool.assert_called_once_with( mapping_user_inputs_to_variable_pool.assert_called_once_with(
@ -734,11 +738,15 @@ class TestWorkflowEntryHelpers:
patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables), patch.object(workflow_entry, "default_system_variables", return_value=sentinel.system_variables),
patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool), patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool),
patch.object(workflow_entry, "add_variables_to_pool"), patch.object(workflow_entry, "add_variables_to_pool"),
patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), patch.object(workflow_entry, "DifyGraphInitContext", return_value=sentinel.graph_init_context),
patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state),
patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}),
patch.object(workflow_entry.time, "perf_counter", return_value=123.0), patch.object(workflow_entry.time, "perf_counter", return_value=123.0),
patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory), patch.object(
workflow_entry.DifyNodeFactory,
"from_graph_init_context",
return_value=dify_node_factory,
),
patch.object( patch.object(
workflow_entry.WorkflowEntry, workflow_entry.WorkflowEntry,
"mapping_user_inputs_to_variable_pool", "mapping_user_inputs_to_variable_pool",


@ -6,23 +6,23 @@ MODULE = "services.plugin.plugin_auto_upgrade_service"
def _patched_session(): def _patched_session():
"""Patch sessionmaker(bind=db.engine).begin() to return a mock session as context manager.""" """Patch session_factory.create_session() to return a mock session as context manager."""
session = MagicMock() session = MagicMock()
mock_sessionmaker = MagicMock() session.__enter__ = MagicMock(return_value=session)
mock_sessionmaker.return_value.begin.return_value.__enter__ = MagicMock(return_value=session) session.__exit__ = MagicMock(return_value=False)
mock_sessionmaker.return_value.begin.return_value.__exit__ = MagicMock(return_value=False) mock_factory = MagicMock()
patcher = patch(f"{MODULE}.sessionmaker", mock_sessionmaker) mock_factory.create_session.return_value = session
db_patcher = patch(f"{MODULE}.db") patcher = patch(f"{MODULE}.session_factory", mock_factory)
return patcher, db_patcher, session return patcher, session
class TestGetStrategy: class TestGetStrategy:
def test_returns_strategy_when_found(self): def test_returns_strategy_when_found(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
strategy = MagicMock() strategy = MagicMock()
session.scalar.return_value = strategy session.scalar.return_value = strategy
with p1, p2: with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.get_strategy("t1") result = PluginAutoUpgradeService.get_strategy("t1")
@ -30,10 +30,10 @@ class TestGetStrategy:
assert result is strategy assert result is strategy
def test_returns_none_when_not_found(self): def test_returns_none_when_not_found(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
session.scalar.return_value = None session.scalar.return_value = None
with p1, p2: with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.get_strategy("t1") result = PluginAutoUpgradeService.get_strategy("t1")
@ -43,10 +43,10 @@ class TestGetStrategy:
class TestChangeStrategy: class TestChangeStrategy:
def test_creates_new_strategy(self): def test_creates_new_strategy(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
session.scalar.return_value = None session.scalar.return_value = None
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.return_value = MagicMock() strat_cls.return_value = MagicMock()
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
@ -63,11 +63,11 @@ class TestChangeStrategy:
session.add.assert_called_once() session.add.assert_called_once()
def test_updates_existing_strategy(self): def test_updates_existing_strategy(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
existing = MagicMock() existing = MagicMock()
session.scalar.return_value = existing session.scalar.return_value = existing
with p1, p2: with p1:
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
result = PluginAutoUpgradeService.change_strategy( result = PluginAutoUpgradeService.change_strategy(
@ -89,12 +89,11 @@ class TestChangeStrategy:
class TestExcludePlugin: class TestExcludePlugin:
def test_creates_default_strategy_when_none_exists(self): def test_creates_default_strategy_when_none_exists(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
session.scalar.return_value = None session.scalar.return_value = None
with ( with (
p1, p1,
p2,
patch(f"{MODULE}.select"), patch(f"{MODULE}.select"),
patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls, patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls,
patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs, patch(f"{MODULE}.PluginAutoUpgradeService.change_strategy") as cs,
@ -110,13 +109,13 @@ class TestExcludePlugin:
cs.assert_called_once() cs.assert_called_once()
def test_appends_to_exclude_list_in_exclude_mode(self): def test_appends_to_exclude_list_in_exclude_mode(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
existing = MagicMock() existing = MagicMock()
existing.upgrade_mode = "exclude" existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p-existing"] existing.exclude_plugins = ["p-existing"]
session.scalar.return_value = existing session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all" strat_cls.UpgradeMode.ALL = "all"
@ -128,13 +127,13 @@ class TestExcludePlugin:
assert existing.exclude_plugins == ["p-existing", "p-new"] assert existing.exclude_plugins == ["p-existing", "p-new"]
def test_removes_from_include_list_in_partial_mode(self): def test_removes_from_include_list_in_partial_mode(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
existing = MagicMock() existing = MagicMock()
existing.upgrade_mode = "partial" existing.upgrade_mode = "partial"
existing.include_plugins = ["p1", "p2"] existing.include_plugins = ["p1", "p2"]
session.scalar.return_value = existing session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all" strat_cls.UpgradeMode.ALL = "all"
@ -146,12 +145,12 @@ class TestExcludePlugin:
assert existing.include_plugins == ["p2"] assert existing.include_plugins == ["p2"]
def test_switches_to_exclude_mode_from_all(self): def test_switches_to_exclude_mode_from_all(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
existing = MagicMock() existing = MagicMock()
existing.upgrade_mode = "all" existing.upgrade_mode = "all"
session.scalar.return_value = existing session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all" strat_cls.UpgradeMode.ALL = "all"
@ -164,13 +163,13 @@ class TestExcludePlugin:
assert existing.exclude_plugins == ["p1"] assert existing.exclude_plugins == ["p1"]
def test_no_duplicate_in_exclude_list(self): def test_no_duplicate_in_exclude_list(self):
p1, p2, session = _patched_session() p1, session = _patched_session()
existing = MagicMock() existing = MagicMock()
existing.upgrade_mode = "exclude" existing.upgrade_mode = "exclude"
existing.exclude_plugins = ["p1"] existing.exclude_plugins = ["p1"]
session.scalar.return_value = existing session.scalar.return_value = existing
with p1, p2, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls: with p1, patch(f"{MODULE}.select"), patch(f"{MODULE}.TenantPluginAutoUpgradeStrategy") as strat_cls:
strat_cls.UpgradeMode.EXCLUDE = "exclude" strat_cls.UpgradeMode.EXCLUDE = "exclude"
strat_cls.UpgradeMode.PARTIAL = "partial" strat_cls.UpgradeMode.PARTIAL = "partial"
strat_cls.UpgradeMode.ALL = "all" strat_cls.UpgradeMode.ALL = "all"
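
The rewritten helper fakes session_factory.create_session() as a context manager that hands back the same mock session. That wiring, standalone:

from unittest.mock import MagicMock

def use_session(session_factory):
    # Stand-in for service code: `with session_factory.create_session() as s:`
    with session_factory.create_session() as s:
        return s.scalar("some statement")

session = MagicMock()
session.__enter__ = MagicMock(return_value=session)  # `as s` binds the session itself
session.__exit__ = MagicMock(return_value=False)     # False: do not swallow exceptions
session.scalar.return_value = "strategy"

factory = MagicMock()
factory.create_session.return_value = session

assert use_session(factory) == "strategy"

Without the explicit __enter__, MagicMock's default context-manager support would hand back a different child mock than the one the test configured.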


@ -165,7 +165,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.key = "test_key" upload_file.key = "test_key"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage: with patch("services.file_service.storage") as mock_storage:
mock_storage.load_once.return_value = b"test content" mock_storage.load_once.return_value = b"test content"
@ -178,7 +178,7 @@ class TestFileService:
mock_storage.load_once.assert_called_once_with("test_key") mock_storage.load_once.assert_called_once_with("test_key")
def test_get_file_base64_not_found(self, file_service, mock_db_session): def test_get_file_base64_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"): with pytest.raises(NotFound, match="File not found"):
file_service.get_file_base64("non_existent") file_service.get_file_base64("non_existent")
@ -215,7 +215,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.extension = "pdf" upload_file.extension = "pdf"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract: with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract:
mock_extract.return_value = "Extracted text content" mock_extract.return_value = "Extracted text content"
@ -227,7 +227,7 @@ class TestFileService:
assert result == "Extracted text content" assert result == "Extracted text content"
def test_get_file_preview_not_found(self, file_service, mock_db_session): def test_get_file_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"): with pytest.raises(NotFound, match="File not found"):
file_service.get_file_preview("non_existent") file_service.get_file_preview("non_existent")
@ -235,7 +235,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.extension = "exe" upload_file.extension = "exe"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError): with pytest.raises(UnsupportedFileTypeError):
file_service.get_file_preview("file_id") file_service.get_file_preview("file_id")
@ -246,7 +246,7 @@ class TestFileService:
upload_file.extension = "jpg" upload_file.extension = "jpg"
upload_file.mime_type = "image/jpeg" upload_file.mime_type = "image/jpeg"
upload_file.key = "key" upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with ( with (
patch("services.file_service.file_helpers.verify_image_signature") as mock_verify, patch("services.file_service.file_helpers.verify_image_signature") as mock_verify,
@ -269,7 +269,7 @@ class TestFileService:
file_service.get_image_preview("file_id", "ts", "nonce", "sign") file_service.get_image_preview("file_id", "ts", "nonce", "sign")
def test_get_image_preview_not_found(self, file_service, mock_db_session): def test_get_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"): with pytest.raises(NotFound, match="File not found or signature is invalid"):
@ -279,7 +279,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.extension = "txt" upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify:
mock_verify.return_value = True mock_verify.return_value = True
with pytest.raises(UnsupportedFileTypeError): with pytest.raises(UnsupportedFileTypeError):
@ -289,7 +289,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.key = "key" upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with ( with (
patch("services.file_service.file_helpers.verify_file_signature") as mock_verify, patch("services.file_service.file_helpers.verify_file_signature") as mock_verify,
@ -309,7 +309,7 @@ class TestFileService:
file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign")
def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None mock_db_session.scalar.return_value = None
with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify:
mock_verify.return_value = True mock_verify.return_value = True
with pytest.raises(NotFound, match="File not found or signature is invalid"): with pytest.raises(NotFound, match="File not found or signature is invalid"):
@ -321,7 +321,7 @@ class TestFileService:
upload_file.extension = "png" upload_file.extension = "png"
upload_file.mime_type = "image/png" upload_file.mime_type = "image/png"
upload_file.key = "key" upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage: with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"image content" mock_storage.load.return_value = b"image content"
@ -330,7 +330,7 @@ class TestFileService:
assert mime == "image/png" assert mime == "image/png"
def test_get_public_image_preview_not_found(self, file_service, mock_db_session): def test_get_public_image_preview_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found or signature is invalid"): with pytest.raises(NotFound, match="File not found or signature is invalid"):
file_service.get_public_image_preview("file_id") file_service.get_public_image_preview("file_id")
@ -338,7 +338,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.extension = "txt" upload_file.extension = "txt"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with pytest.raises(UnsupportedFileTypeError): with pytest.raises(UnsupportedFileTypeError):
file_service.get_public_image_preview("file_id") file_service.get_public_image_preview("file_id")
@ -346,7 +346,7 @@ class TestFileService:
upload_file = MagicMock(spec=UploadFile) upload_file = MagicMock(spec=UploadFile)
upload_file.id = "file_id" upload_file.id = "file_id"
upload_file.key = "key" upload_file.key = "key"
mock_db_session.query().where().first.return_value = upload_file mock_db_session.scalar.return_value = upload_file
with patch("services.file_service.storage") as mock_storage: with patch("services.file_service.storage") as mock_storage:
mock_storage.load.return_value = b"hello world" mock_storage.load.return_value = b"hello world"
@ -354,7 +354,7 @@ class TestFileService:
assert result == "hello world" assert result == "hello world"
def test_get_file_content_not_found(self, file_service, mock_db_session): def test_get_file_content_not_found(self, file_service, mock_db_session):
mock_db_session.query().where().first.return_value = None mock_db_session.scalar.return_value = None
with pytest.raises(NotFound, match="File not found"): with pytest.raises(NotFound, match="File not found"):
file_service.get_file_content("file_id") file_service.get_file_content("file_id")


@ -657,7 +657,7 @@ def _app(**kwargs: Any) -> App:
def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None: def test_get_webhook_trigger_and_workflow_should_raise_when_webhook_not_found(monkeypatch: pytest.MonkeyPatch) -> None:
# Arrange # Arrange
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.return_value = _FakeQuery(None) fake_session.scalar.return_value = None
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act / Assert # Act / Assert
@ -671,7 +671,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_not_foun
# Arrange # Arrange
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(None)] fake_session.scalar.side_effect = [webhook_trigger, None]
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act / Assert # Act / Assert
@ -686,7 +686,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_rate_lim
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED) app_trigger = SimpleNamespace(status=AppTriggerStatus.RATE_LIMITED)
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act / Assert # Act / Assert
@ -701,7 +701,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_app_trigger_disabled
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED) app_trigger = SimpleNamespace(status=AppTriggerStatus.DISABLED)
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger)] fake_session.scalar.side_effect = [webhook_trigger, app_trigger]
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act / Assert # Act / Assert
@ -714,7 +714,7 @@ def test_get_webhook_trigger_and_workflow_should_raise_when_workflow_not_found(m
webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1") webhook_trigger = SimpleNamespace(app_id="app-1", node_id="node-1")
app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED) app_trigger = SimpleNamespace(status=AppTriggerStatus.ENABLED)
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(None)] fake_session.scalar.side_effect = [webhook_trigger, app_trigger, None]
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act / Assert # Act / Assert
@ -732,7 +732,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_non_debug_mod
workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}} workflow.get_node_config_by_id.return_value = {"data": {"key": "value"}}
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(app_trigger), _FakeQuery(workflow)] fake_session.scalar.side_effect = [webhook_trigger, app_trigger, workflow]
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act # Act
@ -751,7 +751,7 @@ def test_get_webhook_trigger_and_workflow_should_return_values_for_debug_mode(mo
workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}} workflow.get_node_config_by_id.return_value = {"data": {"mode": "debug"}}
fake_session = MagicMock() fake_session = MagicMock()
fake_session.query.side_effect = [_FakeQuery(webhook_trigger), _FakeQuery(workflow)] fake_session.scalar.side_effect = [webhook_trigger, workflow]
_patch_session(monkeypatch, fake_session) _patch_session(monkeypatch, fake_session)
# Act # Act
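
With the service now on session.scalar(), the fakes script consecutive lookups through side_effect instead of the old _FakeQuery chain. In isolation:

from unittest.mock import MagicMock

def lookup_chain(session):
    # Stand-in for the service's three consecutive single-row lookups.
    trigger = session.scalar("webhook stmt")
    app_trigger = session.scalar("app trigger stmt")
    workflow = session.scalar("workflow stmt")
    return trigger, app_trigger, workflow

session = MagicMock()
# side_effect as a list: each scalar() call returns the next item.
session.scalar.side_effect = ["trigger", "app_trigger", None]

assert lookup_chain(session) == ("trigger", "app_trigger", None)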


@ -2826,9 +2826,9 @@ class TestWorkflowServiceFreeNodeExecution:
variable_pool = MagicMock() variable_pool = MagicMock()
with ( with (
patch("services.workflow_service.GraphInitParams") as mock_graph_init_params, patch("services.workflow_service.DifyGraphInitContext") as mock_graph_init_context_cls,
patch("services.workflow_service.GraphRuntimeState"), patch("services.workflow_service.GraphRuntimeState"),
patch("services.workflow_service.build_dify_run_context"), patch("services.workflow_service.build_dify_run_context") as mock_build_dify_run_context,
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls, patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
patch("services.workflow_service.HumanInputNode") as mock_node_cls, patch("services.workflow_service.HumanInputNode") as mock_node_cls,
): ):
@ -2837,4 +2837,17 @@ class TestWorkflowServiceFreeNodeExecution:
) )
assert node == mock_node_cls.return_value assert node == mock_node_cls.return_value
mock_node_cls.assert_called_once() mock_node_cls.assert_called_once()
mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context) mock_graph_init_context_cls.assert_called_once_with(
workflow_id="wf-1",
graph_config=workflow.graph_dict,
run_context=mock_build_dify_run_context.return_value,
call_depth=0,
)
mock_runtime_cls.assert_called_once_with(mock_build_dify_run_context.return_value)
mock_node_cls.assert_called_once_with(
id="n-1",
config=node_config,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
runtime=mock_runtime_cls.return_value,
)
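
Because graph_runtime_state is built inside the service, the new assert matches it with unittest.mock.ANY while pinning down the arguments that matter:

from unittest.mock import ANY, MagicMock

node_cls = MagicMock()
node_cls(id="n-1", config={"type": "human-input"}, graph_runtime_state=object())
node_cls.assert_called_once_with(
    id="n-1",
    config={"type": "human-input"},
    graph_runtime_state=ANY,  # internal value; ANY compares equal to anything
)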


@ -19,7 +19,57 @@ from core.tools.utils.text_processing_utils import remove_leading_symbols
("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"), ("[Google](https://google.com) is a search engine", "[Google](https://google.com) is a search engine"),
("[Example](http://example.com) some text", "[Example](http://example.com) some text"), ("[Example](http://example.com) some text", "[Example](http://example.com) some text"),
# Leading symbols before a markdown link are removed; the opening bracket [ is preserved # Leading symbols before a markdown link are removed; the opening bracket [ is preserved
("@[Test](https://example.com)", "Test](https://example.com)"), ("@[Test](https://example.com)", "[Test](https://example.com)"),
("~~标题~~", "标题~~"),
('""quoted', "quoted"),
("''test", "test"),
("##话题", "话题"),
("$$价格", "价格"),
("%%百分比", "百分比"),
("&&与逻辑", "与逻辑"),
("((括号))", "括号))"),
("**强调**", "强调**"),
("++自增", "自增"),
(",,逗号", "逗号"),
("..省略", "省略"),
("//注释", "注释"),
("::范围", "范围"),
(";;分号", "分号"),
("<<左移", "左移"),
("==等于", "等于"),
(">>右移", "右移"),
("??疑问", "疑问"),
("@@提及", "提及"),
("^^上标", "上标"),
("__下划线", "下划线"),
("``代码", "代码"),
("~~删除线", "删除线"),
(" 全角空格开头", "全角空格开头"),
("、顿号开头", "顿号开头"),
("。句号开头", "句号开头"),
("「引号」测试", "引号」测试"),
("『书名号』", "书名号』"),
("【保留】测试", "【保留】测试"),
("〖括号〗测试", "括号〗测试"),
("〔括号〕测试", "括号〕测试"),
("~~【保留】~~", "【保留】~~"),
('"[公告]"', '[公告]"'),
("[公告] 更新", "[公告] 更新"),
("【通知】重要", "【通知】重要"),
("[[嵌套]]", "[[嵌套]]"),
("【【嵌套】】", "【【嵌套】】"),
("[【混合】]", "[【混合】]"),
("normal text", "normal text"),
("123数字", "123数字"),
("中文开头", "中文开头"),
("alpha", "alpha"),
("~", ""),
("", ""),
("[", "["),
("~~~", ""),
("【【【", "【【【"),
("\t制表符", "\t制表符"),
("\n换行", "\n换行"),
], ],
) )
def test_remove_leading_symbols(input_text, expected_output): def test_remove_leading_symbols(input_text, expected_output):

api/uv.lock (generated)

@ -1585,7 +1585,7 @@ dev = [
{ name = "lxml-stubs", specifier = "~=0.5.1" }, { name = "lxml-stubs", specifier = "~=0.5.1" },
{ name = "mypy", specifier = "~=1.20.0" }, { name = "mypy", specifier = "~=1.20.0" },
{ name = "pandas-stubs", specifier = "~=3.0.0" }, { name = "pandas-stubs", specifier = "~=3.0.0" },
{ name = "pyrefly", specifier = ">=0.59.1" }, { name = "pyrefly", specifier = ">=0.60.0" },
{ name = "pytest", specifier = "~=9.0.2" }, { name = "pytest", specifier = "~=9.0.2" },
{ name = "pytest-benchmark", specifier = "~=5.2.3" }, { name = "pytest-benchmark", specifier = "~=5.2.3" },
{ name = "pytest-cov", specifier = "~=7.1.0" }, { name = "pytest-cov", specifier = "~=7.1.0" },
@ -4850,19 +4850,19 @@ wheels = [
[[package]] [[package]]
name = "pyrefly" name = "pyrefly"
version = "0.59.1" version = "0.60.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d5/ce/7882c2af92b2ff6505fcd3430eff8048ece6c6254cc90bdc76ecee12dfab/pyrefly-0.59.1.tar.gz", hash = "sha256:bf1675b0c38d45df2c8f8618cbdfa261a1b92430d9d31eba16e0282b551e210f", size = 5475432, upload-time = "2026-04-01T22:04:04.11Z" } sdist = { url = "https://files.pythonhosted.org/packages/c6/c7/28d14b64888e2d03815627ebff8d57a9f08389c4bbebfe70ae1ed98a1267/pyrefly-0.60.0.tar.gz", hash = "sha256:2499f5b6ff5342e86dfe1cd94bcce133519bbbc93b7ad5636195fea4f0fa3b81", size = 5500389, upload-time = "2026-04-06T19:57:30.643Z" }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/d0/10/04a0e05b08fc855b6fe38c3df549925fc3c2c6e750506870de7335d3e1f7/pyrefly-0.59.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:390db3cd14aa7e0268e847b60cd9ee18b04273eddfa38cf341ed3bb43f3fef2a", size = 12868133, upload-time = "2026-04-01T22:03:39.436Z" }, { url = "https://files.pythonhosted.org/packages/31/99/6c9984a09220e5eb7dd5c869b7a32d25c3d06b5e8854c6eb679db1145c3e/pyrefly-0.60.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:bf1691af0fee69d0c99c3c6e9d26ab6acd3c8afef96416f9ba2e74934833b7b5", size = 12921262, upload-time = "2026-04-06T19:57:00.745Z" },
{ url = "https://files.pythonhosted.org/packages/c7/78/fa7be227c3e3fcacee501c1562278dd026186ffd1b5b5beb51d3941a3aed/pyrefly-0.59.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d246d417b6187c1650d7f855f61c68fbfd6d6155dc846d4e4d273a3e6b5175cb", size = 12379325, upload-time = "2026-04-01T22:03:42.046Z" }, { url = "https://files.pythonhosted.org/packages/05/b3/6216aa3c00c88e59a27eb4149851b5affe86eeea6129f4224034a32dddb0/pyrefly-0.60.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:3e71b70c9b95545cf3b479bc55d1381b531de7b2380eb64411088a1e56b634cb", size = 12424413, upload-time = "2026-04-06T19:57:03.417Z" },
{ url = "https://files.pythonhosted.org/packages/bb/13/6828ce1c98171b5f8388f33c4b0b9ea2ab8c49abe0ef8d793c31e30a05cb/pyrefly-0.59.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:575ac67b04412dc651a7143d27e38a40fbdd3c831c714d5520d0e9d4c8631ab4", size = 35826408, upload-time = "2026-04-01T22:03:45.067Z" }, { url = "https://files.pythonhosted.org/packages/9b/87/eb8dd73abd92a93952ac27a605e463c432fb250fb23186574038c7035594/pyrefly-0.60.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:680ee5f8f98230ea145652d7344708f5375786209c5bf03d8b911fdb0d0d4195", size = 35940884, upload-time = "2026-04-06T19:57:06.909Z" },
{ url = "https://files.pythonhosted.org/packages/23/56/79ed8ece9a7ecad0113c394a06a084107db3ad8f1fefe19e7ded43c51245/pyrefly-0.59.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:062e6262ce1064d59dcad81ac0499bb7a3ad501e9bc8a677a50dc630ff0bf862", size = 38532699, upload-time = "2026-04-01T22:03:48.376Z" }, { url = "https://files.pythonhosted.org/packages/0d/34/dc6aeb67b840c745fcee6db358295d554abe6ab555a7eaaf44624bd80bf1/pyrefly-0.60.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d0b20dbbe4aff15b959e8d825b7521a144c4122c11e57022e83b36568c54470", size = 38677220, upload-time = "2026-04-06T19:57:11.235Z" },
{ url = "https://files.pythonhosted.org/packages/18/7d/ecc025e0f0e3f295b497f523cc19cefaa39e57abede8fc353d29445d174b/pyrefly-0.59.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:43ef4247f9e6f734feb93e1f2b75335b943629956e509f545cc9cdcccd76dd20", size = 36743570, upload-time = "2026-04-01T22:03:51.362Z" }, { url = "https://files.pythonhosted.org/packages/66/6b/c863fcf7ef592b7d1db91502acf0d1113be8bed7a2a7143fc6f0dd90616f/pyrefly-0.60.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2911563c8e6b2eaefff68885c94727965469a35375a409235a7a4d2b7157dc15", size = 36907431, upload-time = "2026-04-06T19:57:15.074Z" },
{ url = "https://files.pythonhosted.org/packages/2f/03/b1ce882ebcb87c673165c00451fbe4df17bf96ccfde18c75880dc87c5f5e/pyrefly-0.59.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:59a2d01723b84d042f4fa6ec871ffd52d0a7e83b0ea791c2e0bb0ff750abce56", size = 41236246, upload-time = "2026-04-01T22:03:54.361Z" }, { url = "https://files.pythonhosted.org/packages/8e/a2/25ea095ab2ecca8e62884669b11a79f14299db93071685b73a97efbaf4f3/pyrefly-0.60.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0a631d9d04705e303fe156f2e62551611bc7ef8066c34708ceebcfb3088bd55", size = 41447898, upload-time = "2026-04-06T19:57:19.382Z" },
{ url = "https://files.pythonhosted.org/packages/17/af/5e9c7afd510e7dd64a2204be0ed39e804089cbc4338675a28615c7176acb/pyrefly-0.59.1-py3-none-win32.whl", hash = "sha256:4ea70c780848f8376411e787643ae5d2d09da8a829362332b7b26d15ebcbaf56", size = 11884747, upload-time = "2026-04-01T22:03:56.776Z" }, { url = "https://files.pythonhosted.org/packages/8e/2c/097bdc6e8d40676b28eb03710a4577bc3c7b803cd24693ac02bf15de3d67/pyrefly-0.60.0-py3-none-win32.whl", hash = "sha256:a08d69298da5626cf502d3debbb6944fd13d2f405ea6625363751f1ff570d366", size = 11913434, upload-time = "2026-04-06T19:57:22.887Z" },
{ url = "https://files.pythonhosted.org/packages/aa/c1/7db1077627453fd1068f0761f059a9512645c00c4c20acfb9f0c24ac02ec/pyrefly-0.59.1-py3-none-win_amd64.whl", hash = "sha256:67e6a08cfd129a0d2788d5e40a627f9860e0fe91a876238d93d5c63ff4af68ae", size = 12720608, upload-time = "2026-04-01T22:03:59.252Z" }, { url = "https://files.pythonhosted.org/packages/0a/d4/8d27fe310e830c8d11ab73db38b93f9fd2e218744b6efb1204401c9a74d5/pyrefly-0.60.0-py3-none-win_amd64.whl", hash = "sha256:56cf30654e708ae1dd635ffefcba4fa4b349dd7004a6ccc5c41e3a9bb944320c", size = 12745033, upload-time = "2026-04-06T19:57:25.517Z" },
{ url = "https://files.pythonhosted.org/packages/07/16/4bb6e5fce5a9cf0992932d9435d964c33e507aaaf96fdfbb1be493078a4a/pyrefly-0.59.1-py3-none-win_arm64.whl", hash = "sha256:01179cb215cf079e8223a064f61a074f7079aa97ea705cbbc68af3d6713afd15", size = 12223158, upload-time = "2026-04-01T22:04:01.869Z" }, { url = "https://files.pythonhosted.org/packages/1f/ad/8eea1f8fb8209f91f6dbfe48000c9d05fd0cdb1b5b3157283c9b1dada55d/pyrefly-0.60.0-py3-none-win_arm64.whl", hash = "sha256:b6d27fba970f4777063c0227c54167d83bece1804ea34f69e7118e409ba038d2", size = 12246390, upload-time = "2026-04-06T19:57:28.141Z" },
] ]
[[package]] [[package]]

pnpm-lock.yaml (generated)

@@ -520,8 +520,8 @@ catalogs:
       specifier: 13.0.0
       version: 13.0.0
     vinext:
-      specifier: https://pkg.pr.new/vinext@adbf24d
-      version: 0.0.5
+      specifier: 0.0.41
+      version: 0.0.41
     vite-plugin-inspect:
       specifier: 12.0.0-beta.1
       version: 12.0.0-beta.1
@@ -1162,7 +1162,7 @@ importers:
       version: 3.19.3
     vinext:
       specifier: 'catalog:'
-      version: https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
+      version: 0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2)
     vite:
       specifier: npm:@voidzero-dev/vite-plus-core@0.1.16
       version: '@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)'
@@ -8336,9 +8336,8 @@ packages:
   vfile@6.0.3:
     resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==}

-  vinext@https://pkg.pr.new/vinext@adbf24d:
-    resolution: {tarball: https://pkg.pr.new/vinext@adbf24d}
-    version: 0.0.5
+  vinext@0.0.41:
+    resolution: {integrity: sha512-fpQjNp6cIqjYGH2/kbhN2SdIYHEu79RdlII23SWsY1Qp7LM+je8GfTJH1sxw6dASxPhZKZB/jCmTm5d2/D25zw==}
     engines: {node: '>=22'}
     hasBin: true
     peerDependencies:
@@ -16586,7 +16585,7 @@ snapshots:
     '@types/unist': 3.0.3
     vfile-message: 4.0.3

-  vinext@https://pkg.pr.new/vinext@adbf24d(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
+  vinext@0.0.41(@mdx-js/rollup@3.1.1(rollup@4.59.0))(@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3)))(@vitejs/plugin-rsc@0.5.23(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5))(@voidzero-dev/vite-plus-core@0.1.16(@types/node@25.5.2)(jiti@2.6.1)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@6.0.2)(yaml@2.8.3))(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react-server-dom-webpack@19.2.5(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(webpack@5.105.4(uglify-js@3.19.3)))(react@19.2.5)(typescript@6.0.2):
     dependencies:
       '@unpic/react': 1.0.2(next@16.2.3(@babel/core@7.29.0)(@playwright/test@1.59.1)(react-dom@19.2.5(react@19.2.5))(react@19.2.5)(sass@1.98.0))(react-dom@19.2.5(react@19.2.5))(react@19.2.5)
       '@vercel/og': 0.8.6


@@ -221,7 +221,7 @@ catalog:
   unist-util-visit: 5.1.0
   use-context-selector: 2.0.0
   uuid: 13.0.0
-  vinext: https://pkg.pr.new/vinext@adbf24d
+  vinext: 0.0.41
   vite: npm:@voidzero-dev/vite-plus-core@0.1.16
   vite-plugin-inspect: 12.0.0-beta.1
   vite-plus: 0.1.16


@@ -229,6 +229,23 @@ describe('output-schema-utils', () => {
     })
   })

+  describe('Dify compact types (workflow-as-tool output_schema)', () => {
+    it('should resolve array[string] to arrayString (issue #34428)', () => {
+      const result = resolveVarType({ type: 'array[string]' })
+      expect(result.type).toBe(VarType.arrayString)
+    })
+
+    it('should resolve Array[string] case-insensitively', () => {
+      const result = resolveVarType({ type: 'Array[string]' })
+      expect(result.type).toBe(VarType.arrayString)
+    })
+
+    it('should resolve array[object] to arrayObject', () => {
+      const result = resolveVarType({ type: 'array[object]' })
+      expect(result.type).toBe(VarType.arrayObject)
+    })
+  })
+
   describe('unknown types', () => {
     it('should resolve unknown type to any', () => {
       const result = resolveVarType({ type: 'unknown_type' })


@@ -2,6 +2,30 @@ import type { SchemaTypeDefinition } from '@/service/use-common'
 import { VarType } from '@/app/components/workflow/types'
 import { getMatchedSchemaType } from '../_base/components/variable/use-match-schema-type'

+/**
+ * Workflow-as-tool and some internal APIs store Dify VarType strings (e.g. `array[string]`)
+ * in JSON Schema `type` instead of the standard `{ type: 'array', items: { type: 'string' } }`.
+ * Map those compact strings to VarType so downstream consumers (e.g. the Code node variable
+ * picker) do not fall back to `any` and get filtered out.
+ */
+const resolveDifyCompactTypeString = (typeStr: string): VarType | undefined => {
+  const trimmed = typeStr.trim()
+  const m = /^array\[(string|number|integer|boolean|object|file|any)\]$/i.exec(trimmed)
+  if (!m)
+    return undefined
+  const inner = m[1].toLowerCase()
+  const map: Record<string, VarType> = {
+    string: VarType.arrayString,
+    number: VarType.arrayNumber,
+    integer: VarType.arrayNumber,
+    boolean: VarType.arrayBoolean,
+    object: VarType.arrayObject,
+    file: VarType.arrayFile,
+    any: VarType.arrayAny,
+  }
+  return map[inner]
+}
+
 /**
  * Normalizes a JSON Schema type to a simple string type.
  * Handles complex schemas with oneOf, anyOf, allOf.
@@ -54,6 +78,12 @@ export const resolveVarType = (
   schemaTypeDefinitions?: SchemaTypeDefinition[],
 ): { type: VarType, schemaType?: string } => {
   const schemaType = getMatchedSchemaType(schema, schemaTypeDefinitions)
+  if (schema && typeof schema.type === 'string') {
+    const compact = resolveDifyCompactTypeString(schema.type)
+    if (compact !== undefined)
+      return { type: compact, schemaType }
+  }
+
   const normalizedType = normalizeJsonSchemaType(schema)
   switch (normalizedType) {
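
Taken together, these two hunks make compact Dify type strings short-circuit variable resolution before JSON Schema normalization runs. The sketch below illustrates the resulting behavior, mirroring the new tests; the relative import path is an assumption, since the diff does not show the module's filename.

// Sketch only: './output-schema-utils' is an assumed path for the module patched above.
import { resolveVarType } from './output-schema-utils'
import { VarType } from '@/app/components/workflow/types'

// Compact Dify strings are intercepted before normalizeJsonSchemaType runs:
const a = resolveVarType({ type: 'array[string]' })  // a.type === VarType.arrayString
const b = resolveVarType({ type: 'Array[Object]' })  // b.type === VarType.arrayObject (the regex is case-insensitive)
const c = resolveVarType({ type: 'array[integer]' }) // c.type === VarType.arrayNumber

// Strings the regex rejects (e.g. 'array[date]') make resolveDifyCompactTypeString
// return undefined, so resolution falls through to the existing JSON Schema path
// unchanged.

Note the design choice of mapping `integer` onto arrayNumber, presumably because VarType has no dedicated integer-array variant.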