diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index fa0fd2ee8c..e7d5f60288 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -14,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml -echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" +echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index e32f2814eb..aeea446c6e 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -22,7 +22,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from core.workflow.graph_engine.manager import GraphEngineManager from libs import helper -from libs.login import current_user +from libs.login import current_user as current_user_ from models.model import AppMode, InstalledApp from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError @@ -31,6 +31,8 @@ from .. import console_ns logger = logging.getLogger(__name__) +current_user = current_user_._get_current_object() # type: ignore + @console_ns.route("/installed-apps//workflows/run") class InstalledAppWorkflowRunApi(InstalledAppResource): diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 2fa28711c3..8572a6dc9b 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -303,7 +303,12 @@ def edit_permission_required(f: Callable[P, R]): def decorated_function(*args: P.args, **kwargs: P.kwargs): from werkzeug.exceptions import Forbidden - current_user, _ = current_account_with_tenant() + from libs.login import current_user + from models import Account + + user = current_user._get_current_object() # type: ignore + if not isinstance(user, Account): + raise Forbidden() if not current_user.has_edit_permission: raise Forbidden() return f(*args, **kwargs) diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 47d297e194..002415a7db 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -1,11 +1,9 @@ import logging from threading import Lock -from typing import Union import contexts from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController -from core.datasource.entities.common_entities import I18nObject from core.datasource.entities.datasource_entities import DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController @@ -18,11 +16,6 @@ logger = logging.getLogger(__name__) class DatasourceManager: - _builtin_provider_lock = Lock() - _hardcoded_providers: dict[str, DatasourcePluginProviderController] = {} - _builtin_providers_loaded = False - _builtin_tools_labels: dict[str, Union[I18nObject, None]] = {} - @classmethod def get_datasource_plugin_provider( cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 4793d2bb50..d2d8fcf964 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -250,7 +250,6 @@ class WeaviateVector(BaseVector): ) ) - batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100)) with col.batch.dynamic() as batch: for obj in objs: batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector) @@ -348,7 +347,10 @@ class WeaviateVector(BaseVector): for obj in res.objects: properties = dict(obj.properties or {}) text = properties.pop(Field.TEXT_KEY.value, "") - distance = (obj.metadata.distance if obj.metadata else None) or 1.0 + if obj.metadata and obj.metadata.distance is not None: + distance = obj.metadata.distance + else: + distance = 1.0 score = 1.0 - distance if score > score_threshold: diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 5adf04611d..50c2327004 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -3,6 +3,7 @@ import logging from collections.abc import Generator from typing import Any +from flask import has_request_context from sqlalchemy import select from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod @@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs.login import current_user -from models.model import App +from models import Account, Tenant +from models.model import App, EndUser from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -79,11 +81,16 @@ class WorkflowTool(Tool): generator = WorkflowAppGenerator() assert self.runtime is not None assert self.runtime.invoke_from is not None - assert current_user is not None + + user = self._resolve_user(user_id=user_id) + + if user is None: + raise ToolInvokeError("User not found") + result = generator.generate( app_model=app, workflow=workflow, - user=current_user, + user=user, args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, streaming=False, @@ -123,6 +130,51 @@ class WorkflowTool(Tool): label=self.label, ) + def _resolve_user(self, user_id: str) -> Account | EndUser | None: + """ + Resolve user object in both HTTP and worker contexts. + + In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser). + In worker context: load Account from database by user_id (only returns Account, never EndUser). + + Returns: + Account | EndUser | None: The resolved user object, or None if resolution fails. + """ + if has_request_context(): + return self._resolve_user_from_request() + else: + return self._resolve_user_from_database(user_id=user_id) + + def _resolve_user_from_request(self) -> Account | EndUser | None: + """ + Resolve user from Flask request context. + """ + try: + # Note: `current_user` is a LocalProxy. Never compare it with None directly. + return getattr(current_user, "_get_current_object", lambda: current_user)() + except Exception as e: + logger.warning("Failed to resolve user from request context: %s", e) + return None + + def _resolve_user_from_database(self, user_id: str) -> Account | None: + """ + Resolve user from database (worker/Celery context). + """ + + user_stmt = select(Account).where(Account.id == user_id) + user = db.session.scalar(user_stmt) + if not user: + return None + + tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id) + tenant = db.session.scalar(tenant_stmt) + if not tenant: + return None + + user.current_tenant = tenant + + return user + def _get_workflow(self, app_id: str, version: str) -> Workflow: """ get the workflow by app id and version diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/core/workflow/graph_engine/orchestration/dispatcher.py index 8340c10b49..f3570855ce 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/core/workflow/graph_engine/orchestration/dispatcher.py @@ -99,6 +99,8 @@ class Dispatcher: self._execution_coordinator.check_commands() self._event_queue.task_done() except queue.Empty: + # Process commands even when no new events arrive so abort requests are not missed + self._execution_coordinator.check_commands() # Check if execution is complete if self._execution_coordinator.is_execution_complete(): break diff --git a/api/libs/login.py b/api/libs/login.py index 2c75ef9297..d0e81a3441 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -1,6 +1,6 @@ from collections.abc import Callable from functools import wraps -from typing import Union, cast +from typing import Any from flask import current_app, g, has_request_context, request from flask_login.config import EXEMPT_METHODS # type: ignore @@ -10,16 +10,21 @@ from configs import dify_config from models import Account from models.model import EndUser -#: A proxy for the current user. If no user is logged in, this will be an -#: anonymous user -current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user())) - def current_account_with_tenant(): - if not isinstance(current_user, Account): + """ + Resolve the underlying account for the current user proxy and ensure tenant context exists. + Allows tests to supply plain Account mocks without the LocalProxy helper. + """ + user_proxy = current_user + + get_current_object = getattr(user_proxy, "_get_current_object", None) + user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore + + if not isinstance(user, Account): raise ValueError("current_user must be an Account instance") - assert current_user.current_tenant_id is not None, "The tenant information should be loaded." - return current_user, current_user.current_tenant_id + assert user.current_tenant_id is not None, "The tenant information should be loaded." + return user, user.current_tenant_id from typing import ParamSpec, TypeVar @@ -81,3 +86,9 @@ def _get_user() -> EndUser | Account | None: return g._login_user # type: ignore return None + + +#: A proxy for the current user. If no user is logged in, this will be an +#: anonymous user +# NOTE: Any here, but use _get_current_object to check the fields +current_user: Any = LocalProxy(lambda: _get_user()) diff --git a/api/models/model.py b/api/models/model.py index 2373421e7d..af22ab9538 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -1479,7 +1479,7 @@ class EndUser(Base, UserMixin): sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"), ) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=True) type: Mapped[str] = mapped_column(String(255), nullable=False) diff --git a/api/pyproject.toml b/api/pyproject.toml index 1bbee3de5b..5289ef620d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -168,6 +168,7 @@ dev = [ "mypy~=1.17.1", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "sseclient-py>=1.8.0", + "pytest-timeout>=2.4.0", ] ############################################################ @@ -216,7 +217,7 @@ vdb = [ "tidb-vector==0.0.9", "upstash-vector==0.6.0", "volcengine-compat~=1.0.0", - "weaviate-client>=4.0.0,<5.0.0", + "weaviate-client==4.17.0", "xinference-client~=1.2.2", "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index 56e1207bc6..b53320bf5e 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -17,7 +17,6 @@ from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter from extensions.ext_database import db from extensions.ext_redis import redis_client -from libs.login import current_account_with_tenant from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider from models.provider_ids import DatasourceProviderID from services.plugin.plugin_service import PluginService @@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService logger = logging.getLogger(__name__) +def get_current_user(): + from libs.login import current_user + from models.account import Account + from models.model import EndUser + + if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore + raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}") + return current_user + + class DatasourceProviderService: """ Model Provider Service @@ -93,8 +102,6 @@ class DatasourceProviderService: """ get credential by id """ - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: if credential_id: datasource_provider = ( @@ -111,6 +118,7 @@ class DatasourceProviderService: return {} # refresh the credentials if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()): + current_user = get_current_user() decrypted_credentials = self.decrypt_datasource_provider_credentials( tenant_id=tenant_id, datasource_provider=datasource_provider, @@ -159,8 +167,6 @@ class DatasourceProviderService: """ get all datasource credentials by provider """ - current_user, _ = current_account_with_tenant() - with Session(db.engine) as session: datasource_providers = ( session.query(DatasourceProvider) @@ -170,6 +176,7 @@ class DatasourceProviderService: ) if not datasource_providers: return [] + current_user = get_current_user() # refresh the credentials real_credentials_list = [] for datasource_provider in datasource_providers: @@ -608,7 +615,6 @@ class DatasourceProviderService: """ provider_name = provider_id.provider_name plugin_id = provider_id.plugin_id - current_user, _ = current_account_with_tenant() with Session(db.engine) as session: lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}" @@ -630,6 +636,7 @@ class DatasourceProviderService: raise ValueError("Authorization name is already exists") try: + current_user = get_current_user() self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, @@ -907,7 +914,6 @@ class DatasourceProviderService: """ update datasource credentials. """ - current_user, _ = current_account_with_tenant() with Session(db.engine) as session: datasource_provider = ( @@ -944,6 +950,7 @@ class DatasourceProviderService: for key, value in credentials.items() } try: + current_user = get_current_user() self.provider_manager.validate_provider_credentials( tenant_id=tenant_id, user_id=current_user.id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py new file mode 100644 index 0000000000..e4db14623d --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_register_task.py @@ -0,0 +1,134 @@ +""" +TestContainers-based integration tests for mail_register_task.py + +This module provides integration tests for email registration tasks +using TestContainers to ensure real database and service interactions. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from libs.email_i18n import EmailType +from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist + + +class TestMailRegisterTask: + """Integration tests for mail_register_task using testcontainers.""" + + @pytest.fixture + def mock_mail_dependencies(self): + """Mock setup for mail service dependencies.""" + with ( + patch("tasks.mail_register_task.mail") as mock_mail, + patch("tasks.mail_register_task.get_email_i18n_service") as mock_get_email_service, + ): + # Setup mock mail service + mock_mail.is_inited.return_value = True + + # Setup mock email i18n service + mock_email_service = MagicMock() + mock_get_email_service.return_value = mock_email_service + + yield { + "mail": mock_mail, + "email_service": mock_email_service, + "get_email_service": mock_get_email_service, + } + + def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies): + """Test successful email registration mail sending.""" + fake = Faker() + language = "en-US" + to_email = fake.email() + code = fake.numerify("######") + + send_email_register_mail_task(language=language, to=to_email, code=code) + + mock_mail_dependencies["mail"].is_inited.assert_called_once() + mock_mail_dependencies["email_service"].send_email.assert_called_once_with( + email_type=EmailType.EMAIL_REGISTER, + language_code=language, + to=to_email, + template_context={ + "to": to_email, + "code": code, + }, + ) + + def test_send_email_register_mail_task_mail_not_initialized( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test email registration task when mail service is not initialized.""" + mock_mail_dependencies["mail"].is_inited.return_value = False + + send_email_register_mail_task(language="en-US", to="test@example.com", code="123456") + + mock_mail_dependencies["get_email_service"].assert_not_called() + mock_mail_dependencies["email_service"].send_email.assert_not_called() + + def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies): + """Test email registration task exception handling.""" + mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") + + fake = Faker() + to_email = fake.email() + code = fake.numerify("######") + + with patch("tasks.mail_register_task.logger") as mock_logger: + send_email_register_mail_task(language="en-US", to=to_email, code=code) + mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) + + def test_send_email_register_mail_task_when_account_exist_success( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test successful email registration mail sending when account exists.""" + fake = Faker() + language = "en-US" + to_email = fake.email() + account_name = fake.name() + + with patch("tasks.mail_register_task.dify_config") as mock_config: + mock_config.CONSOLE_WEB_URL = "https://console.dify.ai" + + send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name) + + mock_mail_dependencies["email_service"].send_email.assert_called_once_with( + email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST, + language_code=language, + to=to_email, + template_context={ + "to": to_email, + "login_url": "https://console.dify.ai/signin", + "reset_password_url": "https://console.dify.ai/reset-password", + "account_name": account_name, + }, + ) + + def test_send_email_register_mail_task_when_account_exist_mail_not_initialized( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test account exist email task when mail service is not initialized.""" + mock_mail_dependencies["mail"].is_inited.return_value = False + + send_email_register_mail_task_when_account_exist( + language="en-US", to="test@example.com", account_name="Test User" + ) + + mock_mail_dependencies["get_email_service"].assert_not_called() + mock_mail_dependencies["email_service"].send_email.assert_not_called() + + def test_send_email_register_mail_task_when_account_exist_exception_handling( + self, db_session_with_containers, mock_mail_dependencies + ): + """Test account exist email task exception handling.""" + mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error") + + fake = Faker() + to_email = fake.email() + account_name = fake.name() + + with patch("tasks.mail_register_task.logger") as mock_logger: + send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name) + mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email) diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 17e3ebeea0..c68aad0b22 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -34,12 +34,17 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + # Mock user resolution to avoid database access + from unittest.mock import Mock + + mock_user = Mock() + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) + # replace `WorkflowAppGenerator.generate` 's return value. monkeypatch.setattr( "core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", lambda *args, **kwargs: {"data": {"error": "oops"}}, ) - monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None) with pytest.raises(ToolInvokeError) as exc_info: # WorkflowTool always returns a generator, so we need to iterate to diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index fc38393e75..96926797ec 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -7,14 +7,11 @@ This test suite validates the behavior of a workflow that: 3. Handles multiple answer nodes with different outputs """ -import pytest - from core.workflow.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, - NodeRunSucceededEvent, ) from .test_mock_config import MockConfigBuilder @@ -29,7 +26,6 @@ class TestComplexBranchWorkflow: self.runner = TableTestRunner() self.fixture_path = "test_complex_branch" - @pytest.mark.skip(reason="output in this workflow can be random") def test_hello_branch_with_llm(self): """ Test when query contains 'hello' - should trigger true branch. @@ -41,42 +37,17 @@ class TestComplexBranchWorkflow: fixture_path=self.fixture_path, query="hello world", expected_outputs={ - "answer": f"{mock_text_1}contains 'hello'", + "answer": f"contains 'hello'{mock_text_1}", }, description="Basic hello case with parallel LLM execution", use_auto_mock=True, mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - expected_event_sequence=[ - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # If/Else (no streaming) - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LLM (with streaming) - NodeRunStartedEvent, - ] - # LLM - + [NodeRunStreamChunkEvent] * (mock_text_1.count(" ") + 2) - + [ - # Answer's text - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Answer 2 - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], ), WorkflowTestCase( fixture_path=self.fixture_path, query="say hello to everyone", expected_outputs={ - "answer": "Mocked response for greetingcontains 'hello'", + "answer": "contains 'hello'Mocked response for greeting", }, description="Hello in middle of sentence", use_auto_mock=True, @@ -93,6 +64,35 @@ class TestComplexBranchWorkflow: for result in suite_result.results: assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" assert result.actual_outputs + assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) + assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) + + start_index = next( + idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent) + ) + success_index = max( + idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) + ) + assert start_index < success_index + + started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} + assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( + f"Branch or LLM nodes missing in events: {started_node_ids}" + ) + + assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( + "Expected streaming chunks from LLM execution" + ) + + llm_start_index = next( + idx + for idx, event in enumerate(result.events) + if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" + ) + assert any( + idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) + for idx, event in enumerate(result.events) + ), "Streaming chunks should follow LLM node start" def test_non_hello_branch_with_llm(self): """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py index 830fc0884d..0d612e054f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher.py @@ -95,10 +95,10 @@ def _make_succeeded_event() -> NodeRunSucceededEvent: ) -def test_dispatcher_checks_commands_after_node_completion() -> None: - """Dispatcher should only check commands after node completion events.""" +def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: + """Dispatcher polls commands when idle and re-checks after completion events.""" started_checks = _run_dispatcher_for_event(_make_started_event()) succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) - assert started_checks == 0 - assert succeeded_checks == 1 + assert started_checks == 1 + assert succeeded_checks == 2 diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 61ce640edd..94c638bb0f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -21,7 +21,6 @@ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool -from core.workflow.graph import Graph from core.workflow.nodes.llm import llm_utils from core.workflow.nodes.llm.entities import ( ContextConfig, @@ -83,14 +82,6 @@ def graph_init_params() -> GraphInitParams: ) -@pytest.fixture -def graph() -> Graph: - # TODO: This fixture uses old Graph constructor parameters that are incompatible - # with the new queue-based engine. Need to rewrite for new engine architecture. - pytest.skip("Graph fixture incompatible with new queue-based engine - needs rewrite for ResponseStreamCoordinator") - return Graph() - - @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( @@ -105,7 +96,7 @@ def graph_runtime_state() -> GraphRuntimeState: @pytest.fixture def llm_node( - llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState + llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState ) -> LLMNode: mock_file_saver = mock.MagicMock(spec=LLMFileSaver) node_config = { @@ -493,9 +484,7 @@ def test_handle_list_messages_basic(llm_node): @pytest.fixture -def llm_node_for_multimodal( - llm_node_data, graph_init_params, graph, graph_runtime_state -) -> tuple[LLMNode, LLMFileSaver]: +def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]: mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) node_config = { "id": "1", @@ -655,7 +644,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: gen = llm_node._save_multimodal_output_and_convert_result_to_markdown( contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[] ) - assert list(gen) == ["frozenset({'hello world'})"] + assert list(gen) == ["hello world"] mock_file_saver.save_binary_string.assert_not_called() mock_file_saver.save_remote_url.assert_not_called() diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py index acfc5cc526..3832a0b8b2 100644 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ b/api/tests/unit_tests/services/auth/test_auth_integration.py @@ -181,14 +181,11 @@ class TestAuthIntegration: ) def test_all_providers_factory_creation(self, provider, credentials): """Test factory creation for all supported providers""" - try: - auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) - assert auth_class is not None + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class is not None - factory = ApiKeyAuthFactory(provider, credentials) - assert factory.auth is not None - except ImportError: - pytest.skip(f"Provider {provider} not implemented yet") + factory = ApiKeyAuthFactory(provider, credentials) + assert factory.auth is not None def _create_success_response(self, status_code=200): """Create successful HTTP response mock""" diff --git a/api/tests/unit_tests/services/test_metadata_bug_complete.py b/api/tests/unit_tests/services/test_metadata_bug_complete.py index 31fe9b2868..ee96305070 100644 --- a/api/tests/unit_tests/services/test_metadata_bug_complete.py +++ b/api/tests/unit_tests/services/test_metadata_bug_complete.py @@ -41,7 +41,10 @@ class TestMetadataBugCompleteValidation: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # Should crash with TypeError with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) @@ -51,7 +54,10 @@ class TestMetadataBugCompleteValidation: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) diff --git a/api/tests/unit_tests/services/test_metadata_nullable_bug.py b/api/tests/unit_tests/services/test_metadata_nullable_bug.py index c8cd7025c2..3d57737943 100644 --- a/api/tests/unit_tests/services/test_metadata_nullable_bug.py +++ b/api/tests/unit_tests/services/test_metadata_nullable_bug.py @@ -29,7 +29,10 @@ class TestMetadataNullableBug: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) @@ -40,7 +43,10 @@ class TestMetadataNullableBug: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # This should crash with TypeError when calling len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.update_metadata_name("dataset-123", "metadata-456", None) @@ -88,7 +94,10 @@ class TestMetadataNullableBug: mock_user.current_tenant_id = "tenant-123" mock_user.id = "user-456" - with patch("services.metadata_service.current_user", mock_user): + with patch( + "services.metadata_service.current_account_with_tenant", + return_value=(mock_user, mock_user.current_tenant_id), + ): # Step 4: Service layer crashes on len(None) with pytest.raises(TypeError, match="object of type 'NoneType' has no len"): MetadataService.create_metadata("dataset-123", mock_metadata_args) diff --git a/api/uv.lock b/api/uv.lock index f0ed31cfc3..2eb90b5e02 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1409,6 +1409,7 @@ dev = [ { name = "pytest-cov" }, { name = "pytest-env" }, { name = "pytest-mock" }, + { name = "pytest-timeout" }, { name = "ruff" }, { name = "scipy-stubs" }, { name = "sseclient-py" }, @@ -1600,6 +1601,7 @@ dev = [ { name = "pytest-cov", specifier = "~=4.1.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", specifier = "~=3.14.0" }, + { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "ruff", specifier = "~=0.14.0" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, @@ -1684,7 +1686,7 @@ vdb = [ { name = "tidb-vector", specifier = "==0.0.9" }, { name = "upstash-vector", specifier = "==0.6.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, - { name = "weaviate-client", specifier = ">=4.0.0,<5.0.0" }, + { name = "weaviate-client", specifier = "==4.17.0" }, { name = "xinference-client", specifier = "~=1.2.2" }, ] @@ -4996,6 +4998,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" }, ] +[[package]] +name = "pytest-timeout" +version = "2.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" }, +] + [[package]] name = "python-calamine" version = "0.5.3" diff --git a/dev/pytest/pytest_artifacts.sh b/dev/pytest/pytest_artifacts.sh index 3086ef5cc4..29cacdcc07 100755 --- a/dev/pytest/pytest_artifacts.sh +++ b/dev/pytest/pytest_artifacts.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/artifact_tests/ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/artifact_tests/ diff --git a/dev/pytest/pytest_model_runtime.sh b/dev/pytest/pytest_model_runtime.sh index 2cbbbbfd81..fd68dbe697 100755 --- a/dev/pytest/pytest_model_runtime.sh +++ b/dev/pytest/pytest_model_runtime.sh @@ -4,7 +4,9 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/integration_tests/model_runtime/anthropic \ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/model_runtime/anthropic \ api/tests/integration_tests/model_runtime/azure_openai \ api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \ api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \ diff --git a/dev/pytest/pytest_testcontainers.sh b/dev/pytest/pytest_testcontainers.sh index e55a436138..f92f8821bf 100755 --- a/dev/pytest/pytest_testcontainers.sh +++ b/dev/pytest/pytest_testcontainers.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/test_containers_integration_tests +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/test_containers_integration_tests diff --git a/dev/pytest/pytest_tools.sh b/dev/pytest/pytest_tools.sh index d10934626f..989784f078 100755 --- a/dev/pytest/pytest_tools.sh +++ b/dev/pytest/pytest_tools.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/integration_tests/tools +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/tools diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index 1a1819ca28..496cb40952 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -4,5 +4,7 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}" + # libs -pytest api/tests/unit_tests +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index 7f617a9c05..3c11a079cc 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -4,7 +4,9 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/integration_tests/vdb/chroma \ +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/milvus \ api/tests/integration_tests/vdb/pgvecto_rs \ api/tests/integration_tests/vdb/pgvector \ diff --git a/dev/pytest/pytest_workflow.sh b/dev/pytest/pytest_workflow.sh index b63d49069f..941c8d3e7e 100755 --- a/dev/pytest/pytest_workflow.sh +++ b/dev/pytest/pytest_workflow.sh @@ -4,4 +4,6 @@ set -x SCRIPT_DIR="$(dirname "$(realpath "$0")")" cd "$SCRIPT_DIR/../.." -pytest api/tests/integration_tests/workflow +PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}" + +pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/workflow diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 27c428a38b..12dc3aa48d 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -24,6 +24,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -51,6 +58,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -331,7 +345,6 @@ services: weaviate: image: semitechnologies/weaviate:1.27.0 profiles: - - "" - weaviate restart: always volumes: diff --git a/docker/docker-compose.override.yml b/docker/docker-compose.override.yml deleted file mode 100644 index 8f2ab1cb43..0000000000 --- a/docker/docker-compose.override.yml +++ /dev/null @@ -1,9 +0,0 @@ -services: - api: - volumes: - - ../api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:/app/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:ro - command: > - sh -c " - pip install --no-cache-dir 'weaviate>=4.0.0' && - /bin/bash /entrypoint.sh - " diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 7988d65b2d..5c5448a33a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -636,6 +636,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -663,6 +670,13 @@ services: volumes: # Mount the storage directory to the container, for storing user files. - ./volumes/app/storage:/app/api/storage + # TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release + entrypoint: + - /bin/bash + - -c + - | + uv pip install --system weaviate-client==4.17.0 + exec /bin/bash /app/api/docker/entrypoint.sh networks: - ssrf_proxy_network - default @@ -943,7 +957,6 @@ services: weaviate: image: semitechnologies/weaviate:1.27.0 profiles: - - "" - weaviate restart: always volumes: diff --git a/docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md b/docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md new file mode 100644 index 0000000000..b2599e8c2e --- /dev/null +++ b/docs/weaviate/WEAVIATE_MIGRATION_GUIDE/README.md @@ -0,0 +1,187 @@ +# Weaviate Migration Guide: v1.19 → v1.27 + +## Overview + +Dify has upgraded from Weaviate v1.19 to v1.27 with the Python client updated from v3.24 to v4.17. + +## What Changed + +### Breaking Changes + +1. **Weaviate Server**: `1.19.0` → `1.27.0` +1. **Python Client**: `weaviate-client~=3.24.0` → `weaviate-client==4.17.0` +1. **gRPC Required**: Weaviate v1.27 requires gRPC port `50051` (in addition to HTTP port `8080`) +1. **Docker Compose**: Added temporary entrypoint overrides for client installation + +### Key Improvements + +- Faster vector operations via gRPC +- Improved batch processing +- Better error handling + +## Migration Steps + +### For Docker Users + +#### Step 1: Backup Your Data + +```bash +cd docker +docker compose down +sudo cp -r ./volumes/weaviate ./volumes/weaviate_backup_$(date +%Y%m%d) +``` + +#### Step 2: Update Dify + +```bash +git pull origin main +docker compose pull +``` + +#### Step 3: Start Services + +```bash +docker compose up -d +sleep 30 +curl http://localhost:8080/v1/meta +``` + +#### Step 4: Verify Migration + +```bash +# Check both ports are accessible +curl http://localhost:8080/v1/meta +netstat -tulpn | grep 50051 + +# Test in Dify UI: +# 1. Go to Knowledge Base +# 2. Test search functionality +# 3. Upload a test document +``` + +### For Source Installation + +#### Step 1: Update Dependencies + +```bash +cd api +uv sync --dev +uv run python -c "import weaviate; print(weaviate.__version__)" +# Should show: 4.17.0 +``` + +#### Step 2: Update Weaviate Server + +```bash +cd docker +docker compose -f docker-compose.middleware.yaml --profile weaviate up -d weaviate +curl http://localhost:8080/v1/meta +netstat -tulpn | grep 50051 +``` + +## Troubleshooting + +### Error: "No module named 'weaviate.classes'" + +**Solution**: + +```bash +cd api +uv sync --reinstall-package weaviate-client +uv run python -c "import weaviate; print(weaviate.__version__)" +# Should show: 4.17.0 +``` + +### Error: "gRPC health check failed" + +**Solution**: + +```bash +# Check Weaviate ports +docker ps | grep weaviate +# Should show: 0.0.0.0:8080->8080/tcp, 0.0.0.0:50051->50051/tcp + +# If missing gRPC port, add to docker-compose: +# ports: +# - "8080:8080" +# - "50051:50051" +``` + +### Error: "Weaviate version 1.19.0 is not supported" + +**Solution**: + +```bash +# Update Weaviate image in docker-compose +# Change: semitechnologies/weaviate:1.19.0 +# To: semitechnologies/weaviate:1.27.0 +docker compose down +docker compose up -d +``` + +### Data Migration Failed + +**Solution**: + +```bash +cd docker +docker compose down +sudo rm -rf ./volumes/weaviate +sudo cp -r ./volumes/weaviate_backup_YYYYMMDD ./volumes/weaviate +docker compose up -d +``` + +## Rollback Instructions + +```bash +# 1. Stop services +docker compose down + +# 2. Restore data backup +sudo rm -rf ./volumes/weaviate +sudo cp -r ./volumes/weaviate_backup_YYYYMMDD ./volumes/weaviate + +# 3. Checkout previous version +git checkout + +# 4. Restart services +docker compose up -d +``` + +## Compatibility + +| Component | Old Version | New Version | Compatible | +|-----------|-------------|-------------|------------| +| Weaviate Server | 1.19.0 | 1.27.0 | ✅ Yes | +| weaviate-client | ~3.24.0 | ==4.17.0 | ✅ Yes | +| Existing Data | v1.19 format | v1.27 format | ✅ Yes | + +## Testing Checklist + +Before deploying to production: + +- [ ] Backup all Weaviate data +- [ ] Test in staging environment +- [ ] Verify existing collections are accessible +- [ ] Test vector search functionality +- [ ] Test document upload and retrieval +- [ ] Monitor gRPC connection stability +- [ ] Check performance metrics + +## Support + +If you encounter issues: + +1. Check GitHub Issues: https://github.com/langgenius/dify/issues +1. Create a bug report with: + - Error messages + - Docker logs: `docker compose logs weaviate` + - Dify version + - Migration steps attempted + +## Important Notes + +- **Data Safety**: Existing vector data remains fully compatible +- **No Re-indexing**: No need to rebuild vector indexes +- **Temporary Workaround**: The entrypoint overrides are temporary until next Dify release +- **Performance**: May see improved performance due to gRPC usage diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index 2201b28a2f..2b6bd73df0 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -100,7 +100,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut }) } } - + catch (e: any) { + if (e.code === 'authentication_failed') + Toast.notify({ type: 'error', message: e.message }) + } finally { setIsLoading(false) } diff --git a/web/app/components/base/param-item/top-k-item.tsx b/web/app/components/base/param-item/top-k-item.tsx index f59c0f273e..0bd552e48d 100644 --- a/web/app/components/base/param-item/top-k-item.tsx +++ b/web/app/components/base/param-item/top-k-item.tsx @@ -32,7 +32,7 @@ const TopKItem: FC = ({ }) => { const { t } = useTranslation() const handleParamChange = (key: string, value: number) => { - let notOutRangeValue = Number.parseFloat(value.toFixed(2)) + let notOutRangeValue = Number.parseInt(value.toFixed(0)) notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue) notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue) onChange(key, notOutRangeValue) diff --git a/web/app/components/base/textarea/index.tsx b/web/app/components/base/textarea/index.tsx index 63eae48e31..7813eb7209 100644 --- a/web/app/components/base/textarea/index.tsx +++ b/web/app/components/base/textarea/index.tsx @@ -25,8 +25,8 @@ export type TextareaProps = { destructive?: boolean styleCss?: CSSProperties ref?: React.Ref - onFocus?: () => void - onBlur?: () => void + onFocus?: React.FocusEventHandler + onBlur?: React.FocusEventHandler } & React.TextareaHTMLAttributes & VariantProps const Textarea = React.forwardRef( diff --git a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx index 1a76622c57..252c9a7d77 100644 --- a/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/if-else/components/condition-list/condition-item.tsx @@ -234,6 +234,9 @@ const ConditionItem = ({ draft.varType = resolvedVarType draft.value = resolvedVarType === VarType.boolean ? false : '' draft.comparison_operator = getOperators(resolvedVarType)[0] + delete draft.key + delete draft.sub_variable_condition + delete draft.numberVarType setTimeout(() => setControlPromptEditorRerenderKey(Date.now())) }) doUpdateCondition(newCondition) diff --git a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx index 9d46037d84..9eb1cb39c9 100644 --- a/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx +++ b/web/app/components/workflow/nodes/knowledge-base/components/retrieval-setting/top-k-and-score-threshold.tsx @@ -1,8 +1,8 @@ -import { memo } from 'react' +import { memo, useCallback } from 'react' import { useTranslation } from 'react-i18next' import Tooltip from '@/app/components/base/tooltip' -import Input from '@/app/components/base/input' import Switch from '@/app/components/base/switch' +import { InputNumber } from '@/app/components/base/input-number' export type TopKAndScoreThresholdProps = { topK: number @@ -14,6 +14,24 @@ export type TopKAndScoreThresholdProps = { readonly?: boolean hiddenScoreThreshold?: boolean } + +const maxTopK = (() => { + const configValue = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10) + if (configValue && !isNaN(configValue)) + return configValue + return 10 +})() +const TOP_K_VALUE_LIMIT = { + amount: 1, + min: 1, + max: maxTopK, +} +const SCORE_THRESHOLD_VALUE_LIMIT = { + step: 0.01, + min: 0, + max: 1, +} + const TopKAndScoreThreshold = ({ topK, onTopKChange, @@ -25,18 +43,18 @@ const TopKAndScoreThreshold = ({ hiddenScoreThreshold, }: TopKAndScoreThresholdProps) => { const { t } = useTranslation() - const handleTopKChange = (e: React.ChangeEvent) => { - const value = Number(e.target.value) - if (Number.isNaN(value)) - return - onTopKChange?.(value) - } + const handleTopKChange = useCallback((value: number) => { + let notOutRangeValue = Number.parseInt(value.toFixed(0)) + notOutRangeValue = Math.max(TOP_K_VALUE_LIMIT.min, notOutRangeValue) + notOutRangeValue = Math.min(TOP_K_VALUE_LIMIT.max, notOutRangeValue) + onTopKChange?.(notOutRangeValue) + }, [onTopKChange]) - const handleScoreThresholdChange = (e: React.ChangeEvent) => { - const value = Number(e.target.value) - if (Number.isNaN(value)) - return - onScoreThresholdChange?.(value) + const handleScoreThresholdChange = (value: number) => { + let notOutRangeValue = Number.parseFloat(value.toFixed(2)) + notOutRangeValue = Math.max(SCORE_THRESHOLD_VALUE_LIMIT.min, notOutRangeValue) + notOutRangeValue = Math.min(SCORE_THRESHOLD_VALUE_LIMIT.max, notOutRangeValue) + onScoreThresholdChange?.(notOutRangeValue) } return ( @@ -49,11 +67,13 @@ const TopKAndScoreThreshold = ({ popupContent={t('appDebug.datasetConfig.top_kTip')} /> - { @@ -74,11 +94,13 @@ const TopKAndScoreThreshold = ({ popupContent={t('appDebug.datasetConfig.score_thresholdTip')} /> - ) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx index 7016e8bd2a..6421401a2a 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-number.tsx @@ -18,7 +18,7 @@ type ConditionNumberProps = { nodesOutputVars: NodeOutPutVar[] availableNodes: Node[] isCommonVariable?: boolean - commonVariables: { name: string, type: string }[] + commonVariables: { name: string; type: string; value: string }[] } & ConditionValueMethodProps const ConditionNumber = ({ value, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx index cf85f1259b..d5cb06e690 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx +++ b/web/app/components/workflow/nodes/knowledge-retrieval/components/metadata/condition-list/condition-string.tsx @@ -18,7 +18,7 @@ type ConditionStringProps = { nodesOutputVars: NodeOutPutVar[] availableNodes: Node[] isCommonVariable?: boolean - commonVariables: { name: string, type: string }[] + commonVariables: { name: string; type: string; value: string }[] } & ConditionValueMethodProps const ConditionString = ({ value, diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts index 1cae4ecd3b..65f7dc2493 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/types.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/types.ts @@ -128,6 +128,6 @@ export type MetadataShape = { availableNumberVars?: NodeOutPutVar[] availableNumberNodesWithParent?: Node[] isCommonVariable?: boolean - availableCommonStringVars?: { name: string; type: string; }[] - availableCommonNumberVars?: { name: string; type: string; }[] + availableCommonStringVars?: { name: string; type: string; value: string }[] + availableCommonNumberVars?: { name: string; type: string; value: string }[] } diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx index 6368397b74..463d87d7d1 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/json-importer.tsx @@ -24,7 +24,7 @@ const JsonImporter: FC = ({ const [open, setOpen] = useState(false) const [json, setJson] = useState('') const [parseError, setParseError] = useState(null) - const importBtnRef = useRef(null) + const importBtnRef = useRef(null) const advancedEditing = useVisualEditorStore(state => state.advancedEditing) const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField) const { emit } = useMittContext() diff --git a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx index 5bf4b22f11..268683aec3 100644 --- a/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx +++ b/web/app/components/workflow/nodes/llm/components/json-schema-config-modal/visual-editor/context.tsx @@ -18,7 +18,7 @@ type VisualEditorProviderProps = { export const VisualEditorContext = createContext(null) export const VisualEditorContextProvider = ({ children }: VisualEditorProviderProps) => { - const storeRef = useRef() + const storeRef = useRef(null) if (!storeRef.current) storeRef.current = createVisualEditorStore() diff --git a/web/app/components/workflow/nodes/llm/use-config.ts b/web/app/components/workflow/nodes/llm/use-config.ts index 45635be3f2..d11fb6db28 100644 --- a/web/app/components/workflow/nodes/llm/use-config.ts +++ b/web/app/components/workflow/nodes/llm/use-config.ts @@ -23,7 +23,7 @@ const useConfig = (id: string, payload: LLMNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() const isChatMode = useIsChatMode() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' }) const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) const inputRef = useRef(inputs) diff --git a/web/app/components/workflow/nodes/llm/utils.ts b/web/app/components/workflow/nodes/llm/utils.ts index 29591d76ad..10c287f86b 100644 --- a/web/app/components/workflow/nodes/llm/utils.ts +++ b/web/app/components/workflow/nodes/llm/utils.ts @@ -10,7 +10,7 @@ export const checkNodeValid = (_payload: LLMNodeType) => { export const getFieldType = (field: Field) => { const { type, items } = field - if(field.schemaType === 'file') return 'file' + if(field.schemaType === 'file') return Type.file if (type !== Type.array || !items) return type diff --git a/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx b/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx index 6e573093b7..4de07cc3da 100644 --- a/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx +++ b/web/app/components/workflow/nodes/loop/components/condition-list/condition-item.tsx @@ -196,6 +196,9 @@ const ConditionItem = ({ draft.varType = varItem.type draft.value = '' draft.comparison_operator = getOperators(varItem.type)[0] + delete draft.key + delete draft.sub_variable_condition + delete draft.numberVarType }) doUpdateCondition(newCondition) setOpen(false) diff --git a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx index 8cb5a4b877..43b85f9ffc 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx +++ b/web/app/components/workflow/nodes/parameter-extractor/components/extract-parameter/import-from-tool.tsx @@ -9,7 +9,10 @@ import BlockSelector from '../../../../block-selector' import type { Param, ParamType } from '../../types' import cn from '@/utils/classnames' import { useStore } from '@/app/components/workflow/store' -import type { PluginDefaultValue } from '@/app/components/workflow/block-selector/types' +import type { + PluginDefaultValue, + ToolDefaultValue, +} from '@/app/components/workflow/block-selector/types' import type { ToolParameter } from '@/app/components/tools/types' import { CollectionType } from '@/app/components/tools/types' import type { BlockEnum } from '@/app/components/workflow/types' @@ -44,9 +47,10 @@ const ImportFromTool: FC = ({ const workflowTools = useStore(s => s.workflowTools) const handleSelectTool = useCallback((_type: BlockEnum, toolInfo?: PluginDefaultValue) => { - if (!toolInfo || !('tool_name' in toolInfo)) + if (!toolInfo || 'datasource_name' in toolInfo || !('tool_name' in toolInfo)) return - const { provider_id, provider_type, tool_name: tool_name } = toolInfo! + + const { provider_id, provider_type, tool_name } = toolInfo as ToolDefaultValue const currentTools = (() => { switch (provider_type) { case CollectionType.builtIn: diff --git a/web/app/components/workflow/nodes/parameter-extractor/use-config.ts b/web/app/components/workflow/nodes/parameter-extractor/use-config.ts index 2caae83f2a..d3699853c2 100644 --- a/web/app/components/workflow/nodes/parameter-extractor/use-config.ts +++ b/web/app/components/workflow/nodes/parameter-extractor/use-config.ts @@ -27,7 +27,7 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => { const { handleOutVarRenameChange } = useWorkflow() const isChatMode = useIsChatMode() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' }) const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) diff --git a/web/app/components/workflow/nodes/question-classifier/use-config.ts b/web/app/components/workflow/nodes/question-classifier/use-config.ts index b4907641b5..5106f373a8 100644 --- a/web/app/components/workflow/nodes/question-classifier/use-config.ts +++ b/web/app/components/workflow/nodes/question-classifier/use-config.ts @@ -20,7 +20,7 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => { const updateNodeInternals = useUpdateNodeInternals() const { nodesReadOnly: readOnly } = useNodesReadOnly() const isChatMode = useIsChatMode() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const { getBeforeNodesInSameBranch } = useWorkflow() const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start) const startNodeId = startNode?.id diff --git a/web/app/components/workflow/nodes/template-transform/use-config.ts b/web/app/components/workflow/nodes/template-transform/use-config.ts index fa7eb81baf..dc012a3844 100644 --- a/web/app/components/workflow/nodes/template-transform/use-config.ts +++ b/web/app/components/workflow/nodes/template-transform/use-config.ts @@ -13,7 +13,7 @@ import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use const useConfig = (id: string, payload: TemplateTransformNodeType) => { const { nodesReadOnly: readOnly } = useNodesReadOnly() - const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type] + const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type] const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) const inputsRef = useRef(inputs) diff --git a/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx b/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx index 9f14aa1210..6062946ede 100644 --- a/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx +++ b/web/app/components/workflow/run/agent-log/agent-log-nav-more.tsx @@ -9,7 +9,7 @@ import Button from '@/app/components/base/button' import type { AgentLogItemWithChildren } from '@/types/workflow' type AgentLogNavMoreProps = { - options: { id: string; label: string }[] + options: AgentLogItemWithChildren[] onShowAgentOrToolLog: (detail?: AgentLogItemWithChildren) => void } const AgentLogNavMore = ({ @@ -41,10 +41,10 @@ const AgentLogNavMore = ({ { options.map(option => (
{ - onShowAgentOrToolLog(option as AgentLogItemWithChildren) + onShowAgentOrToolLog(option) setOpen(false) }} > diff --git a/web/app/components/workflow/workflow-preview/components/nodes/base.tsx b/web/app/components/workflow/workflow-preview/components/nodes/base.tsx index 55dfac467e..c7483c11bb 100644 --- a/web/app/components/workflow/workflow-preview/components/nodes/base.tsx +++ b/web/app/components/workflow/workflow-preview/components/nodes/base.tsx @@ -23,8 +23,10 @@ import { } from '../node-handle' import ErrorHandleOnNode from '../error-handle-on-node' +type NodeChildElement = ReactElement> + type NodeCardProps = NodeProps & { - children?: ReactElement + children?: NodeChildElement } const BaseCard = ({ diff --git a/web/context/debug-configuration.ts b/web/context/debug-configuration.ts index bbf7be8099..dba2e7a231 100644 --- a/web/context/debug-configuration.ts +++ b/web/context/debug-configuration.ts @@ -242,7 +242,7 @@ const DebugConfigurationContext = createContext({ }, datasetConfigsRef: { current: null, - }, + } as unknown as RefObject, setDatasetConfigs: noop, hasSetContextVar: false, isShowVisionConfig: false, diff --git a/web/hooks/use-oauth.ts b/web/hooks/use-oauth.ts index 74fbf14f0b..34ed8bafb0 100644 --- a/web/hooks/use-oauth.ts +++ b/web/hooks/use-oauth.ts @@ -1,5 +1,6 @@ 'use client' import { useEffect } from 'react' +import { validateRedirectUrl } from '@/utils/urlValidation' export const useOAuthCallback = () => { useEffect(() => { @@ -40,6 +41,7 @@ export const openOAuthPopup = (url: string, callback: (data?: any) => void) => { const left = window.screenX + (window.outerWidth - width) / 2 const top = window.screenY + (window.outerHeight - height) / 2 + validateRedirectUrl(url) const popup = window.open( url, 'OAuth', diff --git a/web/package.json b/web/package.json index 1b73a837d2..c4ee84b717 100644 --- a/web/package.json +++ b/web/package.json @@ -145,7 +145,7 @@ "@babel/core": "^7.28.3", "@chromatic-com/storybook": "^3.1.0", "@eslint-react/eslint-plugin": "^1.15.0", - "@happy-dom/jest-environment": "^20.0.0", + "@happy-dom/jest-environment": "^20.0.2", "@mdx-js/loader": "^3.1.0", "@mdx-js/react": "^3.1.0", "@next/bundle-analyzer": "15.5.4", diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 3a99dbfe6f..2a5aba754d 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -348,8 +348,8 @@ importers: specifier: ^1.15.0 version: 1.52.3(eslint@9.35.0(jiti@2.6.1))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3) '@happy-dom/jest-environment': - specifier: ^20.0.0 - version: 20.0.0(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0) + specifier: ^20.0.2 + version: 20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0) '@mdx-js/loader': specifier: ^3.1.0 version: 3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3)) @@ -1647,8 +1647,8 @@ packages: '@formatjs/intl-localematcher@0.5.10': resolution: {integrity: sha512-af3qATX+m4Rnd9+wHcjJ4w2ijq+rAVP3CCinJQvFv1kgSu1W6jypUmvleJxcewdxmutM8dmIRZFxO/IQBZmP2Q==} - '@happy-dom/jest-environment@20.0.0': - resolution: {integrity: sha512-dUyMDNJzPDFopSDyzKdbeYs8z9B4jLj9kXnru8TjYdGeLsQKf+6r0lq/9T2XVcu04QFxXMykt64A+KjsaJTaNA==} + '@happy-dom/jest-environment@20.0.4': + resolution: {integrity: sha512-75OcYtjO+jqxWiYiXvbwR8JZITX1/8iAjRSRpZ/rNjO6UnYebwX6HdI91Ix09xYZEO1X/xOof6HX1EiZnrgnXA==} engines: {node: '>=20.0.0'} peerDependencies: '@jest/environment': '>=25.0.0' @@ -5582,8 +5582,8 @@ packages: hachure-fill@0.5.2: resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==} - happy-dom@20.0.0: - resolution: {integrity: sha512-GkWnwIFxVGCf2raNrxImLo397RdGhLapj5cT3R2PT7FwL62Ze1DROhzmYW7+J3p9105DYMVenEejEbnq5wA37w==} + happy-dom@20.0.4: + resolution: {integrity: sha512-WxFtvnij6G64/MtMimnZhF0nKx3LUQKc20zjATD6tKiqOykUwQkd+2FW/DZBAFNjk4oWh0xdv/HBleGJmSY/Iw==} engines: {node: '>=20.0.0'} has-flag@4.0.0: @@ -10143,12 +10143,12 @@ snapshots: dependencies: tslib: 2.8.1 - '@happy-dom/jest-environment@20.0.0(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)': + '@happy-dom/jest-environment@20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)': dependencies: '@jest/environment': 29.7.0 '@jest/fake-timers': 29.7.0 '@jest/types': 29.6.3 - happy-dom: 20.0.0 + happy-dom: 20.0.4 jest-mock: 29.7.0 jest-util: 29.7.0 @@ -14802,7 +14802,7 @@ snapshots: hachure-fill@0.5.2: {} - happy-dom@20.0.0: + happy-dom@20.0.4: dependencies: '@types/node': 20.19.20 '@types/whatwg-mimetype': 3.0.2 diff --git a/web/service/base.ts b/web/service/base.ts index af4b01c0ae..ea5f716ba3 100644 --- a/web/service/base.ts +++ b/web/service/base.ts @@ -324,7 +324,7 @@ const baseFetch = base type UploadOptions = { xhr: XMLHttpRequest - method: string + method?: string url?: string headers?: Record data: FormData diff --git a/web/utils/urlValidation.ts b/web/utils/urlValidation.ts new file mode 100644 index 0000000000..abc15a1365 --- /dev/null +++ b/web/utils/urlValidation.ts @@ -0,0 +1,23 @@ +/** + * Validates that a URL is safe for redirection. + * Only allows HTTP and HTTPS protocols to prevent XSS attacks. + * + * @param url - The URL string to validate + * @throws Error if the URL has an unsafe protocol + */ +export function validateRedirectUrl(url: string): void { + try { + const parsedUrl = new URL(url) + if (parsedUrl.protocol !== 'http:' && parsedUrl.protocol !== 'https:') + throw new Error('Authorization URL must be HTTP or HTTPS') + } + catch (error) { + if ( + error instanceof Error + && error.message === 'Authorization URL must be HTTP or HTTPS' + ) + throw error + // If URL parsing fails, it's also invalid + throw new Error(`Invalid URL: ${url}`) + } +}