diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 5ac0e342e6..7e41260eeb 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -95,7 +95,7 @@ class CreateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode") - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -103,7 +103,7 @@ class CreateAppPayload(BaseModel): class UpdateAppPayload(BaseModel): name: str = Field(..., min_length=1, description="App name") description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon") @@ -113,7 +113,7 @@ class UpdateAppPayload(BaseModel): class CopyAppPayload(BaseModel): name: str | None = Field(default=None, description="Name for the copied app") description: str | None = Field(default=None, description="Description for the copied app", max_length=400) - icon_type: str | None = Field(default=None, description="Icon type") + icon_type: IconType | None = Field(default=None, description="Icon type") icon: str | None = Field(default=None, description="Icon") icon_background: str | None = Field(default=None, description="Icon background color") @@ -594,7 +594,7 @@ class AppApi(Resource): args_dict: AppService.ArgsDict = { "name": args.name, "description": args.description or "", - "icon_type": args.icon_type or "", + "icon_type": args.icon_type, "icon": args.icon or "", "icon_background": args.icon_background or "", "use_icon_as_answer_icon": args.use_icon_as_answer_icon or False, diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 2ca1275a8a..e0f1759e5e 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -19,6 +19,7 @@ class RateLimit: _REQUEST_MAX_ALIVE_TIME = 10 * 60 # 10 minutes _ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL = 5 * 60 # recalculate request_count from request_detail every 5 minutes _instance_dict: dict[str, "RateLimit"] = {} + max_active_requests: int def __new__(cls, client_id: str, max_active_requests: int): if client_id not in cls._instance_dict: @@ -27,7 +28,13 @@ class RateLimit: return cls._instance_dict[client_id] def __init__(self, client_id: str, max_active_requests: int): + flush_cache = hasattr(self, "max_active_requests") and self.max_active_requests != max_active_requests self.max_active_requests = max_active_requests + # Only flush here if this instance has already been fully initialized, + # i.e. the Redis key attributes exist. Otherwise, rely on the flush at + # the end of initialization below. + if flush_cache and hasattr(self, "active_requests_key") and hasattr(self, "max_active_requests_key"): + self.flush_cache(use_local_value=True) # must be called after max_active_requests is set if self.disabled(): return @@ -41,8 +48,6 @@ class RateLimit: self.flush_cache(use_local_value=True) def flush_cache(self, use_local_value=False): - if self.disabled(): - return self.last_recalculate_time = time.time() # flush max active requests if use_local_value or not redis_client.exists(self.max_active_requests_key): @@ -50,7 +55,8 @@ class RateLimit: else: self.max_active_requests = int(redis_client.get(self.max_active_requests_key).decode("utf-8")) redis_client.expire(self.max_active_requests_key, timedelta(days=1)) - + if self.disabled(): + return # flush max active requests (in-transit request list) if not redis_client.exists(self.active_requests_key): return diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py index 2be391a424..8682c3682c 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -1,6 +1,9 @@ from __future__ import annotations -from collections.abc import Sequence +import json +import logging +import re +from collections.abc import Mapping, Sequence from typing import Any, cast from core.model_manager import ModelInstance @@ -36,6 +39,11 @@ from .exc import ( ) from .protocols import TemplateRenderer +logger = logging.getLogger(__name__) + +VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") +MAX_RESOLVED_VALUE_LENGTH = 1024 + def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( @@ -475,3 +483,61 @@ def _append_file_prompts( prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) else: prompt_messages.append(UserPromptMessage(content=file_prompts)) + + +def _coerce_resolved_value(raw: str) -> int | float | bool | str: + """Try to restore the original type from a resolved template string. + + Variable references are always resolved to text, but completion params may + expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to + the ``temperature`` parameter). This helper attempts a JSON parse so that + ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not + valid JSON literals are returned as-is. + """ + stripped = raw.strip() + if not stripped: + return raw + + try: + parsed: object = json.loads(stripped) + except (json.JSONDecodeError, ValueError): + return raw + + if isinstance(parsed, (int, float, bool)): + return parsed + return raw + + +def resolve_completion_params_variables( + completion_params: Mapping[str, Any], + variable_pool: VariablePool, +) -> dict[str, Any]: + """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. + + Security notes: + - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to + prevent denial-of-service through excessively large variable payloads. + - This follows the same ``VariablePool.convert_template`` pattern used across + Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream + model plugin receives these values as structured JSON key-value pairs — they + are never concatenated into raw HTTP headers or SQL queries. + - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are + restored to their native type rather than sent as a bare string. + """ + resolved: dict[str, Any] = {} + for key, value in completion_params.items(): + if isinstance(value, str) and VARIABLE_PATTERN.search(value): + segment_group = variable_pool.convert_template(value) + text = segment_group.text + if len(text) > MAX_RESOLVED_VALUE_LENGTH: + logger.warning( + "Resolved value for param '%s' truncated from %d to %d chars", + key, + len(text), + MAX_RESOLVED_VALUE_LENGTH, + ) + text = text[:MAX_RESOLVED_VALUE_LENGTH] + resolved[key] = _coerce_resolved_value(text) + else: + resolved[key] = value + return resolved diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index 5ed90ed7e3..a5492aee6b 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]): # fetch model config model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) model_name = model_instance.model_name model_provider = model_instance.provider model_stop = model_instance.stop diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 3913a27697..e6e8a44d06 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise InvalidModelTypeError("Model is not a Large Language Model") diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 59d0a2a4d8..928618fdbc 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variables = {"query": query} # fetch model instance model_instance = self._model_instance + # Resolve variable references in string-typed completion params + model_instance.parameters = llm_utils.resolve_completion_params_variables( + model_instance.parameters, variable_pool + ) memory = self._memory # fetch instruction node_data.instruction = node_data.instruction or "" diff --git a/api/libs/login.py b/api/libs/login.py index bd5cb5f30d..dce332b01d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -18,15 +18,23 @@ if TYPE_CHECKING: from models.model import EndUser +def _resolve_current_user() -> EndUser | Account | None: + """ + Resolve the current user proxy to its underlying user object. + This keeps unit tests working when they patch `current_user` directly + instead of bootstrapping a full Flask-Login manager. + """ + user_proxy = current_user + get_current_object = getattr(user_proxy, "_get_current_object", None) + return get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + def current_account_with_tenant(): """ Resolve the underlying account for the current user proxy and ensure tenant context exists. Allows tests to supply plain Account mocks without the LocalProxy helper. """ - user_proxy = current_user - - get_current_object = getattr(user_proxy, "_get_current_object", None) - user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + user = _resolve_current_user() if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") @@ -79,9 +87,10 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: return current_app.ensure_sync(func)(*args, **kwargs) - user = _get_user() + user = _resolve_current_user() if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore + g._login_user = user # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. check_csrf_token(request, user.id) diff --git a/api/services/app_service.py b/api/services/app_service.py index c5d1479a20..69c7c0c95a 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -241,7 +241,7 @@ class AppService: class ArgsDict(TypedDict): name: str description: str - icon_type: str + icon_type: IconType | str | None icon: str icon_background: str use_icon_as_answer_icon: bool @@ -257,7 +257,13 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None + icon_type = args.get("icon_type") + if icon_type is None: + resolved_icon_type = app.icon_type + else: + resolved_icon_type = IconType(icon_type) + + app.icon_type = resolved_icon_type app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index d79f80c009..9ca8729b77 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account -from models.model import App, Site +from models.model import App, IconType, Site from services.account_service import AccountService, TenantService from tests.test_containers_integration_tests.helpers import generate_valid_password @@ -463,6 +463,109 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by + def test_update_app_should_preserve_icon_type_when_omitted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app keeps the persisted icon_type when the update payload omits it. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + updated_app = app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": None, + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + + assert updated_app.icon_type == IconType.EMOJI + + def test_update_app_should_reject_empty_icon_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test update_app rejects an explicit empty icon_type. + """ + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=generate_valid_password(fake), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + from services.app_service import AppService + + app_service = AppService() + app = app_service.create_app( + tenant.id, + { + "name": fake.company(), + "description": fake.text(max_nb_chars=100), + "mode": "chat", + "icon_type": "emoji", + "icon": "🎯", + "icon_background": "#45B7D1", + }, + account, + ) + + mock_current_user = create_autospec(Account, instance=True) + mock_current_user.id = account.id + mock_current_user.current_tenant_id = account.current_tenant_id + + with patch("services.app_service.current_user", mock_current_user): + with pytest.raises(ValueError): + app_service.update_app( + app, + { + "name": "Updated App Name", + "description": "Updated app description", + "icon_type": "", + "icon": "🔄", + "icon_background": "#FF8C42", + "use_icon_as_answer_icon": True, + }, + ) + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py index beb8ff55a5..1d1e119fd6 100644 --- a/api/tests/unit_tests/controllers/console/app/test_app_apis.py +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -7,14 +7,19 @@ from __future__ import annotations import uuid from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound +from controllers.console import console_ns from controllers.console.app import ( annotation as annotation_module, ) +from controllers.console.app import ( + app as app_module, +) from controllers.console.app import ( completion as completion_module, ) @@ -203,6 +208,48 @@ class TestCompletionEndpoints: method(app_model=MagicMock(id="app-1")) +class TestAppEndpoints: + """Tests for app endpoints.""" + + def test_app_put_should_preserve_icon_type_when_payload_omits_it(self, app, monkeypatch): + api = app_module.AppApi() + method = _unwrap(api.put) + payload = { + "name": "Updated App", + "description": "Updated description", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + app_service = MagicMock() + app_service.update_app.return_value = SimpleNamespace() + response_model = MagicMock() + response_model.model_dump.return_value = {"id": "app-1"} + + monkeypatch.setattr(app_module, "AppService", lambda: app_service) + monkeypatch.setattr(app_module.AppDetailWithSite, "model_validate", MagicMock(return_value=response_model)) + + with ( + app.test_request_context("/console/api/apps/app-1", method="PUT", json=payload), + patch.object(type(console_ns), "payload", payload), + ): + response = method(app_model=SimpleNamespace(icon_type=app_module.IconType.EMOJI)) + + assert response == {"id": "app-1"} + assert app_service.update_app.call_args.args[1]["icon_type"] is None + + def test_update_app_payload_should_reject_empty_icon_type(self): + with pytest.raises(ValidationError): + app_module.UpdateAppPayload.model_validate( + { + "name": "Updated App", + "description": "Updated description", + "icon_type": "", + "icon": "🤖", + "icon_background": "#FFFFFF", + } + ) + + # ========== OpsTrace Tests ========== class TestOpsTraceEndpoints: """Tests for ops_trace endpoint.""" diff --git a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py index 3db10c1c72..538b130cac 100644 --- a/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py +++ b/api/tests/unit_tests/core/app/features/rate_limiting/test_rate_limit.py @@ -68,8 +68,8 @@ class TestRateLimit: assert rate_limit.disabled() assert not hasattr(rate_limit, "initialized") - def test_should_skip_reinitialization_of_existing_instance(self, redis_patch): - """Test that existing instance doesn't reinitialize.""" + def test_should_flush_cache_when_reinitializing_existing_instance(self, redis_patch): + """Test existing instance refreshes Redis cache on reinitialization.""" redis_patch.configure_mock( **{ "exists.return_value": False, @@ -82,7 +82,37 @@ class TestRateLimit: RateLimit("client1", 10) + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) + + def test_should_reinitialize_after_being_disabled(self, redis_patch): + """Test disabled instance can be reinitialized and writes max_active_requests to Redis.""" + redis_patch.configure_mock( + **{ + "exists.return_value": False, + "setex.return_value": True, + } + ) + + # First construct with max_active_requests = 0 (disabled), which should skip initialization. + RateLimit("client1", 0) + + # Redis should not have been written to during disabled initialization. redis_patch.setex.assert_not_called() + redis_patch.reset_mock() + + # Reinitialize with a positive max_active_requests value; this should not raise + # and must write the max_active_requests key to Redis. + RateLimit("client1", 10) + + redis_patch.setex.assert_called_once_with( + "dify:rate_limit:client1:max_active_requests", + timedelta(days=1), + 10, + ) def test_should_be_disabled_when_max_requests_is_zero_or_negative(self): """Test disabled state for zero or negative limits.""" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index 618a498659..acecbf4944 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -3,7 +3,11 @@ from unittest import mock import pytest from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent +from dify_graph.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage from dify_graph.nodes.llm import llm_utils from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage @@ -11,6 +15,15 @@ from dify_graph.nodes.llm.exc import NoPromptFoundError from dify_graph.runtime import VariablePool +@pytest.fixture +def variable_pool() -> VariablePool: + pool = VariablePool.empty() + pool.add(["node1", "output"], "resolved_value") + pool.add(["node2", "text"], "hello world") + pool.add(["start", "user_input"], "dynamic_param") + return pool + + def _fetch_prompt_messages_with_mocked_content(content): variable_pool = VariablePool.empty() model_instance = mock.MagicMock(spec=ModelInstance) @@ -53,6 +66,159 @@ def _fetch_prompt_messages_with_mocked_content(content): ) +class TestTypeCoercionViaResolve: + """Type coercion is tested through the public resolve_completion_params_variables API.""" + + def test_numeric_string_coerced_to_float(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 0.7 + + def test_integer_string_coerced_to_int(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "1024") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == 1024 + + def test_boolean_string_coerced_to_bool(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "true") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] is True + + def test_plain_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "json_object") + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == "json_object" + + def test_json_object_string_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], '{"key": "val"}') + result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool) + assert result["p"] == '{"key": "val"}' + + def test_mixed_text_and_variable_stays_string(self): + pool = VariablePool.empty() + pool.add(["n", "v"], "0.7") + result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool) + assert result["p"] == "val=0.7" + + +class TestResolveCompletionParamsVariables: + def test_plain_string_values_unchanged(self, variable_pool: VariablePool): + params = {"response_format": "json", "custom_param": "static_value"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "json", "custom_param": "static_value"} + + def test_numeric_values_unchanged(self, variable_pool: VariablePool): + params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024} + + def test_boolean_values_unchanged(self, variable_pool: VariablePool): + params = {"stream": True, "echo": False} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stream": True, "echo": False} + + def test_list_values_unchanged(self, variable_pool: VariablePool): + params = {"stop": ["Human:", "Assistant:"]} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"stop": ["Human:", "Assistant:"]} + + def test_single_variable_reference_resolved(self, variable_pool: VariablePool): + params = {"response_format": "{{#node1.output#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"response_format": "resolved_value"} + + def test_multiple_variable_references_resolved(self, variable_pool: VariablePool): + params = { + "param_a": "{{#node1.output#}}", + "param_b": "{{#node2.text#}}", + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"param_a": "resolved_value", "param_b": "hello world"} + + def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool): + params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"prompt_prefix": "prefix_resolved_value_suffix"} + + def test_mixed_params_types(self, variable_pool: VariablePool): + """Non-string params pass through; string params with variables get resolved.""" + params = { + "temperature": 0.7, + "response_format": "{{#node1.output#}}", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == { + "temperature": 0.7, + "response_format": "resolved_value", + "custom_string": "no_vars_here", + "max_tokens": 512, + "stop": ["\n"], + } + + def test_empty_params(self, variable_pool: VariablePool): + result = llm_utils.resolve_completion_params_variables({}, variable_pool) + + assert result == {} + + def test_unresolvable_variable_keeps_selector_text(self): + """When a referenced variable doesn't exist in the pool, convert_template + falls back to the raw selector path (e.g. 'nonexistent.var').""" + pool = VariablePool.empty() + params = {"format": "{{#nonexistent.var#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert result["format"] == "nonexistent.var" + + def test_multiple_variables_in_single_value(self, variable_pool: VariablePool): + params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"} + + result = llm_utils.resolve_completion_params_variables(params, variable_pool) + + assert result == {"combined": "resolved_value and hello world"} + + def test_original_params_not_mutated(self, variable_pool: VariablePool): + original = {"response_format": "{{#node1.output#}}", "temperature": 0.5} + original_copy = dict(original) + + _ = llm_utils.resolve_completion_params_variables(original, variable_pool) + + assert original == original_copy + + def test_long_value_truncated(self): + pool = VariablePool.empty() + pool.add(["node1", "big"], "x" * 2000) + params = {"param": "{{#node1.big#}}"} + + result = llm_utils.resolve_completion_params_variables(params, pool) + + assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH + + def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): with pytest.raises(NoPromptFoundError): _fetch_prompt_messages_with_mocked_content( diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index a94ba0c00b..8613d89215 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -130,6 +130,25 @@ class TestLoginRequired: assert result == "Synced content" setup_app.ensure_sync.assert_called_once() + @patch("libs.login.check_csrf_token", mock_csrf_check) + def test_patched_current_user_without_login_manager(self, app: Flask): + """Test that patched current_user bypasses login manager bootstrapping.""" + + @login_required + def protected_view(): + return "Protected content" + + mock_user = MockUser("test_user", is_authenticated=True) + mock_proxy = MagicMock() + mock_proxy._get_current_object.return_value = mock_user + + with app.test_request_context(): + app.ensure_sync = lambda func: func + with patch("libs.login.current_user", mock_proxy): + result = protected_view() + assert result == "Protected content" + assert g._login_user == mock_user + @patch("libs.login.check_csrf_token", mock_csrf_check) def test_flask_1_compatibility(self, setup_app: Flask): """Test Flask 1.x compatibility without ensure_sync.""" diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py index bff8dc92c6..95fc28b1e7 100644 --- a/api/tests/unit_tests/services/test_app_service.py +++ b/api/tests/unit_tests/services/test_app_service.py @@ -9,7 +9,7 @@ import pytest from core.errors.error import ProviderTokenNotInitError from models import Account, Tenant -from models.model import App, AppMode +from models.model import App, AppMode, IconType from services.app_service import AppService @@ -411,6 +411,7 @@ class TestAppServiceGetAndUpdate: # Assert assert updated is app + assert updated.icon_type == IconType.IMAGE assert renamed is app assert iconed is app assert site_same is app @@ -419,6 +420,79 @@ class TestAppServiceGetAndUpdate: assert api_changed is app assert mock_db.session.commit.call_count >= 5 + def test_update_app_should_preserve_icon_type_when_not_provided(self, service: AppService) -> None: + """Test update_app keeps the existing icon_type when the payload omits it.""" + # Arrange + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type=IconType.EMOJI, + icon="a", + icon_background="#111", + use_icon_as_answer_icon=False, + max_active_requests=1, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": None, + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + patch("services.app_service.naive_utc_now", return_value="now"), + ): + # Act + updated = service.update_app(app, args) + + # Assert + assert updated is app + assert updated.icon_type == IconType.EMOJI + mock_db.session.commit.assert_called_once() + + def test_update_app_should_reject_empty_icon_type(self, service: AppService) -> None: + """Test update_app rejects an explicit empty icon_type.""" + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type=IconType.EMOJI, + icon="a", + icon_background="#111", + use_icon_as_answer_icon=False, + max_active_requests=1, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": "", + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + ): + with pytest.raises(ValueError): + service.update_app(app, args) + + mock_db.session.commit.assert_not_called() + class TestAppServiceDeleteAndMeta: """Test suite for delete and metadata methods.""" diff --git a/web/app/components/apps/index.tsx b/web/app/components/apps/index.tsx index dce9de190d..b6ca60bd7b 100644 --- a/web/app/components/apps/index.tsx +++ b/web/app/components/apps/index.tsx @@ -8,12 +8,14 @@ import AppListContext from '@/context/app-list-context' import useDocumentTitle from '@/hooks/use-document-title' import { useImportDSL } from '@/hooks/use-import-dsl' import { DSLImportMode } from '@/models/app' +import dynamic from '@/next/dynamic' import { fetchAppDetail } from '@/service/explore' -import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal' -import CreateAppModal from '../explore/create-app-modal' -import TryApp from '../explore/try-app' import List from './list' +const DSLConfirmModal = dynamic(() => import('../app/create-from-dsl-modal/dsl-confirm-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('../explore/create-app-modal'), { ssr: false }) +const TryApp = dynamic(() => import('../explore/try-app'), { ssr: false }) + const Apps = () => { const { t } = useTranslation() diff --git a/web/app/components/apps/list.tsx b/web/app/components/apps/list.tsx index 0d52bd468c..2ef344f816 100644 --- a/web/app/components/apps/list.tsx +++ b/web/app/components/apps/list.tsx @@ -5,11 +5,11 @@ import { useDebounceFn } from 'ahooks' import { parseAsStringLiteral, useQueryState } from 'nuqs' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' +import Checkbox from '@/app/components/base/checkbox' import Input from '@/app/components/base/input' import TabSliderNew from '@/app/components/base/tab-slider-new' import TagFilter from '@/app/components/base/tag-management/filter' import { useStore as useTagStore } from '@/app/components/base/tag-management/store' -import CheckboxWithLabel from '@/app/components/datasets/create/website/base/checkbox-with-label' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' @@ -205,12 +205,12 @@ const List: FC = ({ options={options} />
- + { - return IS_CLOUD_EDITION && !!AMPLITUDE_API_KEY -} - // Map URL pathname to English page name for consistent Amplitude tracking const getEnglishPageName = (pathname: string): string => { // Remove leading slash and get the first segment @@ -59,7 +54,7 @@ const AmplitudeProvider: FC = ({ }) => { useEffect(() => { // Only enable in Saas edition with valid API key - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return // Initialize Amplitude diff --git a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx index b30da72091..5835634eb7 100644 --- a/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx +++ b/web/app/components/base/amplitude/__tests__/AmplitudeProvider.spec.tsx @@ -2,14 +2,24 @@ import * as amplitude from '@amplitude/analytics-browser' import { sessionReplayPlugin } from '@amplitude/plugin-session-replay-browser' import { render } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' +import AmplitudeProvider from '../AmplitudeProvider' const mockConfig = vi.hoisted(() => ({ AMPLITUDE_API_KEY: 'test-api-key', IS_CLOUD_EDITION: true, })) -vi.mock('@/config', () => mockConfig) +vi.mock('@/config', () => ({ + get AMPLITUDE_API_KEY() { + return mockConfig.AMPLITUDE_API_KEY + }, + get IS_CLOUD_EDITION() { + return mockConfig.IS_CLOUD_EDITION + }, + get isAmplitudeEnabled() { + return mockConfig.IS_CLOUD_EDITION && !!mockConfig.AMPLITUDE_API_KEY + }, +})) vi.mock('@amplitude/analytics-browser', () => ({ init: vi.fn(), @@ -27,22 +37,6 @@ describe('AmplitudeProvider', () => { mockConfig.IS_CLOUD_EDITION = true }) - describe('isAmplitudeEnabled', () => { - it('returns true when cloud edition and api key present', () => { - expect(isAmplitudeEnabled()).toBe(true) - }) - - it('returns false when cloud edition but no api key', () => { - mockConfig.AMPLITUDE_API_KEY = '' - expect(isAmplitudeEnabled()).toBe(false) - }) - - it('returns false when not cloud edition', () => { - mockConfig.IS_CLOUD_EDITION = false - expect(isAmplitudeEnabled()).toBe(false) - }) - }) - describe('Component', () => { it('initializes amplitude when enabled', () => { render() diff --git a/web/app/components/base/amplitude/__tests__/index.spec.ts b/web/app/components/base/amplitude/__tests__/index.spec.ts deleted file mode 100644 index 2d7ad6ab84..0000000000 --- a/web/app/components/base/amplitude/__tests__/index.spec.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { describe, expect, it } from 'vitest' -import AmplitudeProvider, { isAmplitudeEnabled } from '../AmplitudeProvider' -import indexDefault, { - isAmplitudeEnabled as indexIsAmplitudeEnabled, - resetUser, - setUserId, - setUserProperties, - trackEvent, -} from '../index' -import { - resetUser as utilsResetUser, - setUserId as utilsSetUserId, - setUserProperties as utilsSetUserProperties, - trackEvent as utilsTrackEvent, -} from '../utils' - -describe('Amplitude index exports', () => { - it('exports AmplitudeProvider as default', () => { - expect(indexDefault).toBe(AmplitudeProvider) - }) - - it('exports isAmplitudeEnabled', () => { - expect(indexIsAmplitudeEnabled).toBe(isAmplitudeEnabled) - }) - - it('exports utils', () => { - expect(resetUser).toBe(utilsResetUser) - expect(setUserId).toBe(utilsSetUserId) - expect(setUserProperties).toBe(utilsSetUserProperties) - expect(trackEvent).toBe(utilsTrackEvent) - }) -}) diff --git a/web/app/components/base/amplitude/__tests__/utils.spec.ts b/web/app/components/base/amplitude/__tests__/utils.spec.ts index ecbc57e387..f1ff5db1e3 100644 --- a/web/app/components/base/amplitude/__tests__/utils.spec.ts +++ b/web/app/components/base/amplitude/__tests__/utils.spec.ts @@ -20,8 +20,10 @@ const MockIdentify = vi.hoisted(() => }, ) -vi.mock('../AmplitudeProvider', () => ({ - isAmplitudeEnabled: () => mockState.enabled, +vi.mock('@/config', () => ({ + get isAmplitudeEnabled() { + return mockState.enabled + }, })) vi.mock('@amplitude/analytics-browser', () => ({ diff --git a/web/app/components/base/amplitude/index.ts b/web/app/components/base/amplitude/index.ts index acc792339e..44cbf728e2 100644 --- a/web/app/components/base/amplitude/index.ts +++ b/web/app/components/base/amplitude/index.ts @@ -1,2 +1,2 @@ -export { default, isAmplitudeEnabled } from './AmplitudeProvider' +export { default } from './lazy-amplitude-provider' export { resetUser, setUserId, setUserProperties, trackEvent } from './utils' diff --git a/web/app/components/base/amplitude/lazy-amplitude-provider.tsx b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx new file mode 100644 index 0000000000..5dfa0e7b53 --- /dev/null +++ b/web/app/components/base/amplitude/lazy-amplitude-provider.tsx @@ -0,0 +1,11 @@ +'use client' + +import type { FC } from 'react' +import type { IAmplitudeProps } from './AmplitudeProvider' +import dynamic from '@/next/dynamic' + +const AmplitudeProvider = dynamic(() => import('./AmplitudeProvider'), { ssr: false }) + +const LazyAmplitudeProvider: FC = props => + +export default LazyAmplitudeProvider diff --git a/web/app/components/base/amplitude/utils.ts b/web/app/components/base/amplitude/utils.ts index 57b96243ec..8faa8e852e 100644 --- a/web/app/components/base/amplitude/utils.ts +++ b/web/app/components/base/amplitude/utils.ts @@ -1,5 +1,5 @@ import * as amplitude from '@amplitude/analytics-browser' -import { isAmplitudeEnabled } from './AmplitudeProvider' +import { isAmplitudeEnabled } from '@/config' /** * Track custom event @@ -7,7 +7,7 @@ import { isAmplitudeEnabled } from './AmplitudeProvider' * @param eventProperties Event properties (optional) */ export const trackEvent = (eventName: string, eventProperties?: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.track(eventName, eventProperties) } @@ -17,7 +17,7 @@ export const trackEvent = (eventName: string, eventProperties?: Record { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.setUserId(userId) } @@ -27,7 +27,7 @@ export const setUserId = (userId: string) => { * @param properties User properties */ export const setUserProperties = (properties: Record) => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return const identifyEvent = new amplitude.Identify() Object.entries(properties).forEach(([key, value]) => { @@ -40,7 +40,7 @@ export const setUserProperties = (properties: Record) => { * Reset user (e.g., when user logs out) */ export const resetUser = () => { - if (!isAmplitudeEnabled()) + if (!isAmplitudeEnabled) return amplitude.reset() } diff --git a/web/app/components/billing/pricing/__tests__/header.spec.tsx b/web/app/components/billing/pricing/__tests__/header.spec.tsx index 0aadc3b0ce..cb8991ff42 100644 --- a/web/app/components/billing/pricing/__tests__/header.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/header.spec.tsx @@ -1,12 +1,14 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' -import { Dialog } from '@/app/components/base/ui/dialog' +import { Dialog, DialogContent } from '@/app/components/base/ui/dialog' import Header from '../header' function renderHeader(onClose: () => void) { return render( -
+ +
+
, ) } @@ -24,7 +26,7 @@ describe('Header', () => { expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) @@ -33,7 +35,7 @@ describe('Header', () => { const handleClose = vi.fn() renderHeader(handleClose) - fireEvent.click(screen.getByRole('button')) + fireEvent.click(screen.getByRole('button', { name: 'common.operation.close' })) expect(handleClose).toHaveBeenCalledTimes(1) }) @@ -41,11 +43,11 @@ describe('Header', () => { describe('Edge Cases', () => { it('should render structural elements with translation keys', () => { - const { container } = renderHeader(vi.fn()) + renderHeader(vi.fn()) - expect(container.querySelector('span')).toBeInTheDocument() - expect(container.querySelector('p')).toBeInTheDocument() - expect(screen.getByRole('button')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() + expect(screen.getByText('billing.plansCommon.title.description')).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'common.operation.close' })).toBeInTheDocument() }) }) }) diff --git a/web/app/components/billing/pricing/__tests__/index.spec.tsx b/web/app/components/billing/pricing/__tests__/index.spec.tsx index 36848cd463..a8d0a4329e 100644 --- a/web/app/components/billing/pricing/__tests__/index.spec.tsx +++ b/web/app/components/billing/pricing/__tests__/index.spec.tsx @@ -68,6 +68,7 @@ describe('Pricing', () => { it('should render pricing header and localized footer link', () => { render() + expect(screen.getByRole('dialog', { name: 'billing.plansCommon.title.plans' })).toBeInTheDocument() expect(screen.getByText('billing.plansCommon.title.plans')).toBeInTheDocument() expect(screen.getByTestId('pricing-link')).toHaveAttribute('href', 'https://dify.ai/en/pricing#plans-and-features') }) diff --git a/web/app/components/billing/pricing/footer.tsx b/web/app/components/billing/pricing/footer.tsx index 0d3fd965b0..1422ec1cb1 100644 --- a/web/app/components/billing/pricing/footer.tsx +++ b/web/app/components/billing/pricing/footer.tsx @@ -28,8 +28,9 @@ const Footer = ({ {t('plansCommon.comparePlanAndFeatures', { ns: 'billing' })} diff --git a/web/app/components/billing/pricing/header.tsx b/web/app/components/billing/pricing/header.tsx index d0ffe100db..5ab1895667 100644 --- a/web/app/components/billing/pricing/header.tsx +++ b/web/app/components/billing/pricing/header.tsx @@ -1,5 +1,6 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' +import { DialogDescription, DialogTitle } from '@/app/components/base/ui/dialog' import { cn } from '@/utils/classnames' import Button from '../../base/button' import DifyLogo from '../../base/logo/dify-logo' @@ -18,24 +19,25 @@ const Header = ({
-
+ - {t('plansCommon.title.plans', { ns: 'billing' })} - +
-

+ {t('plansCommon.title.description', { ns: 'billing' })} -

+ @@ -119,7 +125,6 @@ describe('ModelParameterModal', () => { beforeEach(() => { vi.clearAllMocks() - isAPIKeySet = true isRulesLoading = false parameterRules = [ { @@ -233,6 +238,26 @@ describe('ModelParameterModal', () => { expect(screen.getByTestId('model-selector')).toBeInTheDocument() }) + it('should pass nodesOutputVars and availableNodes to ParameterItem', () => { + const mockNodesOutputVars = [{ nodeId: 'n1', title: 'Node', vars: [] }] + const mockAvailableNodes = [{ id: 'n1', data: { title: 'Node', type: 'llm' } }] + + render( + , + ) + + fireEvent.click(screen.getByText('Open Settings')) + + const paramEl = screen.getByTestId('param-temperature') + expect(paramEl).toHaveAttribute('data-has-nodes-output-vars', 'true') + expect(paramEl).toHaveAttribute('data-has-available-nodes', 'true') + }) + it('should support custom triggers, workflow mode, and missing default model values', async () => { render( ({ @@ -18,6 +23,29 @@ vi.mock('@/app/components/base/tag-input', () => ({ ), })) +let promptEditorOnChange: ((text: string) => void) | undefined +let capturedWorkflowNodesMap: Record | undefined + +vi.mock('@/app/components/base/prompt-editor', () => ({ + default: ({ value, onChange, workflowVariableBlock }: { + value: string + onChange: (text: string) => void + workflowVariableBlock?: { + show: boolean + variables: NodeOutPutVar[] + workflowNodesMap?: Record + } + }) => { + promptEditorOnChange = onChange + capturedWorkflowNodesMap = workflowVariableBlock?.workflowNodesMap + return ( +
+ {value} +
+ ) + }, +})) + describe('ParameterItem', () => { const createRule = (overrides: Partial = {}): ModelParameterRule => ({ name: 'temp', @@ -30,9 +58,10 @@ describe('ParameterItem', () => { beforeEach(() => { vi.clearAllMocks() + promptEditorOnChange = undefined + capturedWorkflowNodesMap = undefined }) - // Float tests it('should render float controls and clamp numeric input to max', () => { const onChange = vi.fn() render() @@ -50,7 +79,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(0.1) }) - // Int tests it('should render int controls and clamp numeric input', () => { const onChange = vi.fn() render() @@ -75,22 +103,17 @@ describe('ParameterItem', () => { it('should render int input without slider if min or max is missing', () => { render() expect(screen.queryByRole('slider')).not.toBeInTheDocument() - // No max -> precision step expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0') }) - // Slider events (uses generic value mock for slider) it('should handle slide change and clamp values', () => { const onChange = vi.fn() render() - // Test that the actual slider triggers the onChange logic correctly - // The implementation of Slider uses onChange(val) directly via the mock fireEvent.click(screen.getByTestId('slider-btn')) expect(onChange).toHaveBeenCalledWith(2) }) - // Text & String tests it('should render exact string input and propagate text changes', () => { const onChange = vi.fn() render() @@ -109,21 +132,17 @@ describe('ParameterItem', () => { it('should render select for string with options', () => { render() - // Select renders the selected value in the trigger expect(screen.getByText('a')).toBeInTheDocument() }) - // Tag Tests it('should render tag input for tag type', () => { const onChange = vi.fn() render() expect(screen.getByText('placeholder')).toBeInTheDocument() - // Trigger mock tag input fireEvent.click(screen.getByTestId('tag-input')) expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2']) }) - // Boolean tests it('should render boolean radios and update value on click', () => { const onChange = vi.fn() render() @@ -131,7 +150,6 @@ describe('ParameterItem', () => { expect(onChange).toHaveBeenCalledWith(false) }) - // Switch tests it('should call onSwitch with current value when optional switch is toggled off', () => { const onSwitch = vi.fn() render() @@ -146,7 +164,6 @@ describe('ParameterItem', () => { expect(screen.queryByRole('switch')).not.toBeInTheDocument() }) - // Default Value Fallbacks (rendering without value) it('should use default values if value is undefined', () => { const { rerender } = render() expect(screen.getByRole('spinbutton')).toHaveValue(0.5) @@ -158,26 +175,102 @@ describe('ParameterItem', () => { expect(screen.getByText('True')).toBeInTheDocument() expect(screen.getByText('False')).toBeInTheDocument() - // Without default - rerender() // min is 0 by default in createRule + rerender() expect(screen.getByRole('spinbutton')).toHaveValue(0) }) - // Input Blur it('should reset input to actual bound value on blur', () => { render() const input = screen.getByRole('spinbutton') - // change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state) - // Actually our test fires a change so localValue = 1, then blur sets it fireEvent.change(input, { target: { value: '5' } }) fireEvent.blur(input) expect(input).toHaveValue(1) }) - // Unsupported it('should render no input for unsupported parameter type', () => { render() expect(screen.queryByRole('textbox')).not.toBeInTheDocument() expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument() }) + + describe('workflow variable reference', () => { + const mockNodesOutputVars: NodeOutPutVar[] = [ + { nodeId: 'node1', title: 'LLM Node', vars: [] }, + ] + const mockAvailableNodes: Node[] = [ + { id: 'node1', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'LLM Node', type: BlockEnum.LLM } } as Node, + { id: 'start', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'Start', type: BlockEnum.Start } } as Node, + ] + + it('should build workflowNodesMap and render PromptEditor for string type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + expect(capturedWorkflowNodesMap!.node1.title).toBe('LLM Node') + expect(capturedWorkflowNodesMap!.sys.title).toBe('workflow.blocks.start') + expect(capturedWorkflowNodesMap!.sys.type).toBe(BlockEnum.Start) + + promptEditorOnChange?.('updated text') + expect(onChange).toHaveBeenCalledWith('updated text') + }) + + it('should build workflowNodesMap and render PromptEditor for text type', () => { + const onChange = vi.fn() + render( + , + ) + + const editor = screen.getByTestId('prompt-editor') + expect(editor).toBeInTheDocument() + expect(editor).toHaveAttribute('data-has-workflow-vars', 'true') + expect(capturedWorkflowNodesMap).toBeDefined() + + promptEditorOnChange?.('new long text') + expect(onChange).toHaveBeenCalledWith('new long text') + }) + + it('should fall back to plain input when not in workflow mode for string type', () => { + render( + , + ) + + expect(screen.queryByTestId('prompt-editor')).not.toBeInTheDocument() + expect(screen.getByRole('textbox')).toBeInTheDocument() + }) + + it('should return undefined workflowNodesMap when not in workflow mode', () => { + render( + , + ) + + expect(capturedWorkflowNodesMap).toBeUndefined() + }) + }) }) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index 6b4018e2aa..ccb2c67a0d 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -9,6 +9,10 @@ import type { } from '../declarations' import type { ParameterValue } from './parameter-item' import type { TriggerProps } from './trigger' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows' @@ -45,6 +49,8 @@ export type ModelParameterModalProps = { readonly?: boolean isInWorkflow?: boolean scope?: string + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } const ModelParameterModal: FC = ({ @@ -61,11 +67,18 @@ const ModelParameterModal: FC = ({ renderTrigger, readonly, isInWorkflow, + nodesOutputVars, + availableNodes, }) => { const { t } = useTranslation() const [open, setOpen] = useState(false) const settingsIconRef = useRef(null) - const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId) + const { + data: parameterRulesData, + isPending, + isLoading, + } = useModelParameterRules(provider, modelId) + const isRulesLoading = isPending || isLoading const { currentProvider, currentModel, @@ -191,7 +204,7 @@ const ModelParameterModal: FC = ({ }
{ - isLoading + isRulesLoading ?
: ( [ @@ -205,6 +218,8 @@ const ModelParameterModal: FC = ({ onChange={v => handleParamChange(parameter.name, v)} onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)} isInWorkflow={isInWorkflow} + nodesOutputVars={nodesOutputVars} + availableNodes={availableNodes} /> )) ) @@ -213,7 +228,7 @@ const ModelParameterModal: FC = ({ ) } { - !parameterRules.length && isLoading && ( + !parameterRules.length && isRulesLoading && (
) } diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx index 86fb6d81d0..01e3f45371 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/parameter-item.tsx @@ -1,11 +1,18 @@ import type { ModelParameterRule } from '../declarations' -import { useEffect, useRef, useState } from 'react' +import type { + Node, + NodeOutPutVar, +} from '@/app/components/workflow/types' +import { useEffect, useMemo, useRef, useState } from 'react' +import { useTranslation } from 'react-i18next' +import PromptEditor from '@/app/components/base/prompt-editor' import Radio from '@/app/components/base/radio' import Slider from '@/app/components/base/slider' import Switch from '@/app/components/base/switch' import TagInput from '@/app/components/base/tag-input' import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select' import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip' +import { BlockEnum } from '@/app/components/workflow/types' import { cn } from '@/utils/classnames' import { useLanguage } from '../hooks' import { isNullOrUndefined } from '../utils' @@ -18,18 +25,43 @@ type ParameterItemProps = { onChange?: (value: ParameterValue) => void onSwitch?: (checked: boolean, assignValue: ParameterValue) => void isInWorkflow?: boolean + nodesOutputVars?: NodeOutPutVar[] + availableNodes?: Node[] } + function ParameterItem({ parameterRule, value, onChange, onSwitch, isInWorkflow, + nodesOutputVars, + availableNodes = [], }: ParameterItemProps) { + const { t } = useTranslation() const language = useLanguage() const [localValue, setLocalValue] = useState(value) const numberInputRef = useRef(null) + const workflowNodesMap = useMemo(() => { + if (!isInWorkflow || !availableNodes.length) + return undefined + + return availableNodes.reduce>>((acc, node) => { + acc[node.id] = { + title: node.data.title, + type: node.data.type, + } + if (node.data.type === BlockEnum.Start) { + acc.sys = { + title: t('blocks.start', { ns: 'workflow' }), + type: BlockEnum.Start, + } + } + return acc + }, {}) + }, [availableNodes, isInWorkflow, t]) + const getDefaultValue = () => { let defaultValue: ParameterValue @@ -196,6 +228,25 @@ function ParameterItem({ } if (parameterRule.type === 'string' && !parameterRule.options?.length) { + if (isInWorkflow && nodesOutputVars) { + return ( +
+ { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return ( + { handleInputChange(text) }} + workflowVariableBlock={{ + show: true, + variables: nodesOutputVars, + workflowNodesMap, + }} + editable + /> +
+ ) + } + return (