Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
lyzno1 2025-10-17 19:21:15 +08:00
commit 8a5174d078
No known key found for this signature in database
56 changed files with 673 additions and 164 deletions

View File

@ -14,4 +14,4 @@ yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.ya
yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml
yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml
echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"
echo "Ports exposed for sandbox, weaviate (HTTP 8080, gRPC 50051), tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss"

View File

@ -22,7 +22,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from libs import helper
from libs.login import current_user
from libs.login import current_user as current_user_
from models.model import AppMode, InstalledApp
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
@ -31,6 +31,8 @@ from .. import console_ns
logger = logging.getLogger(__name__)
current_user = current_user_._get_current_object() # type: ignore
@console_ns.route("/installed-apps/<uuid:installed_app_id>/workflows/run")
class InstalledAppWorkflowRunApi(InstalledAppResource):

View File

@ -303,7 +303,12 @@ def edit_permission_required(f: Callable[P, R]):
def decorated_function(*args: P.args, **kwargs: P.kwargs):
from werkzeug.exceptions import Forbidden
current_user, _ = current_account_with_tenant()
from libs.login import current_user
from models import Account
user = current_user._get_current_object() # type: ignore
if not isinstance(user, Account):
raise Forbidden()
if not current_user.has_edit_permission:
raise Forbidden()
return f(*args, **kwargs)

View File

@ -1,11 +1,9 @@
import logging
from threading import Lock
from typing import Union
import contexts
from core.datasource.__base.datasource_plugin import DatasourcePlugin
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.datasource.entities.common_entities import I18nObject
from core.datasource.entities.datasource_entities import DatasourceProviderType
from core.datasource.errors import DatasourceProviderNotFoundError
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
@ -18,11 +16,6 @@ logger = logging.getLogger(__name__)
class DatasourceManager:
_builtin_provider_lock = Lock()
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
_builtin_providers_loaded = False
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
@classmethod
def get_datasource_plugin_provider(
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType

View File

@ -250,7 +250,6 @@ class WeaviateVector(BaseVector):
)
)
batch_size = max(1, int(dify_config.WEAVIATE_BATCH_SIZE or 100))
with col.batch.dynamic() as batch:
for obj in objs:
batch.add_object(properties=obj.properties, uuid=obj.uuid, vector=obj.vector)
@ -348,7 +347,10 @@ class WeaviateVector(BaseVector):
for obj in res.objects:
properties = dict(obj.properties or {})
text = properties.pop(Field.TEXT_KEY.value, "")
distance = (obj.metadata.distance if obj.metadata else None) or 1.0
if obj.metadata and obj.metadata.distance is not None:
distance = obj.metadata.distance
else:
distance = 1.0
score = 1.0 - distance
if score > score_threshold:

View File

@ -3,6 +3,7 @@ import logging
from collections.abc import Generator
from typing import Any
from flask import has_request_context
from sqlalchemy import select
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models.model import App
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@ -79,11 +81,16 @@ class WorkflowTool(Tool):
generator = WorkflowAppGenerator()
assert self.runtime is not None
assert self.runtime.invoke_from is not None
assert current_user is not None
user = self._resolve_user(user_id=user_id)
if user is None:
raise ToolInvokeError("User not found")
result = generator.generate(
app_model=app,
workflow=workflow,
user=current_user,
user=user,
args={"inputs": tool_parameters, "files": files},
invoke_from=self.runtime.invoke_from,
streaming=False,
@ -123,6 +130,51 @@ class WorkflowTool(Tool):
label=self.label,
)
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
"""
Resolve user object in both HTTP and worker contexts.
In HTTP context: dereference the current_user LocalProxy (can return Account or EndUser).
In worker context: load Account from database by user_id (only returns Account, never EndUser).
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
if has_request_context():
return self._resolve_user_from_request()
else:
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_request(self) -> Account | EndUser | None:
"""
Resolve user from Flask request context.
"""
try:
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
return getattr(current_user, "_get_current_object", lambda: current_user)()
except Exception as e:
logger.warning("Failed to resolve user from request context: %s", e)
return None
def _resolve_user_from_database(self, user_id: str) -> Account | None:
"""
Resolve user from database (worker/Celery context).
"""
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if not user:
return None
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user.current_tenant = tenant
return user
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version

View File

@ -99,6 +99,8 @@ class Dispatcher:
self._execution_coordinator.check_commands()
self._event_queue.task_done()
except queue.Empty:
# Process commands even when no new events arrive so abort requests are not missed
self._execution_coordinator.check_commands()
# Check if execution is complete
if self._execution_coordinator.is_execution_complete():
break

View File

@ -1,6 +1,6 @@
from collections.abc import Callable
from functools import wraps
from typing import Union, cast
from typing import Any
from flask import current_app, g, has_request_context, request
from flask_login.config import EXEMPT_METHODS # type: ignore
@ -10,16 +10,21 @@ from configs import dify_config
from models import Account
from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
def current_account_with_tenant():
if not isinstance(current_user, Account):
"""
Resolve the underlying account for the current user proxy and ensure tenant context exists.
Allows tests to supply plain Account mocks without the LocalProxy helper.
"""
user_proxy = current_user
get_current_object = getattr(user_proxy, "_get_current_object", None)
user = get_current_object() if callable(get_current_object) else user_proxy # type: ignore
if not isinstance(user, Account):
raise ValueError("current_user must be an Account instance")
assert current_user.current_tenant_id is not None, "The tenant information should be loaded."
return current_user, current_user.current_tenant_id
assert user.current_tenant_id is not None, "The tenant information should be loaded."
return user, user.current_tenant_id
from typing import ParamSpec, TypeVar
@ -81,3 +86,9 @@ def _get_user() -> EndUser | Account | None:
return g._login_user # type: ignore
return None
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
# NOTE: Any here, but use _get_current_object to check the fields
current_user: Any = LocalProxy(lambda: _get_user())

View File

@ -1479,7 +1479,7 @@ class EndUser(Base, UserMixin):
sa.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=True)
type: Mapped[str] = mapped_column(String(255), nullable=False)

View File

@ -168,6 +168,7 @@ dev = [
"mypy~=1.17.1",
# "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved.
"sseclient-py>=1.8.0",
"pytest-timeout>=2.4.0",
]
############################################################
@ -216,7 +217,7 @@ vdb = [
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.0",
"weaviate-client>=4.0.0,<5.0.0",
"weaviate-client==4.17.0",
"xinference-client~=1.2.2",
"mo-vector~=0.1.13",
"mysql-connector-python>=9.3.0",

View File

@ -17,7 +17,6 @@ from core.plugin.impl.oauth import OAuthHandler
from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.login import current_account_with_tenant
from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.plugin.plugin_service import PluginService
@ -25,6 +24,16 @@ from services.plugin.plugin_service import PluginService
logger = logging.getLogger(__name__)
def get_current_user():
from libs.login import current_user
from models.account import Account
from models.model import EndUser
if not isinstance(current_user._get_current_object(), (Account, EndUser)): # type: ignore
raise TypeError(f"current_user must be Account or EndUser, got {type(current_user).__name__}")
return current_user
class DatasourceProviderService:
"""
Model Provider Service
@ -93,8 +102,6 @@ class DatasourceProviderService:
"""
get credential by id
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
if credential_id:
datasource_provider = (
@ -111,6 +118,7 @@ class DatasourceProviderService:
return {}
# refresh the credentials
if datasource_provider.expires_at != -1 and (datasource_provider.expires_at - 60) < int(time.time()):
current_user = get_current_user()
decrypted_credentials = self.decrypt_datasource_provider_credentials(
tenant_id=tenant_id,
datasource_provider=datasource_provider,
@ -159,8 +167,6 @@ class DatasourceProviderService:
"""
get all datasource credentials by provider
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
datasource_providers = (
session.query(DatasourceProvider)
@ -170,6 +176,7 @@ class DatasourceProviderService:
)
if not datasource_providers:
return []
current_user = get_current_user()
# refresh the credentials
real_credentials_list = []
for datasource_provider in datasource_providers:
@ -608,7 +615,6 @@ class DatasourceProviderService:
"""
provider_name = provider_id.provider_name
plugin_id = provider_id.plugin_id
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
lock = f"datasource_provider_create_lock:{tenant_id}_{provider_id}_{CredentialType.API_KEY}"
@ -630,6 +636,7 @@ class DatasourceProviderService:
raise ValueError("Authorization name is already exists")
try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,
@ -907,7 +914,6 @@ class DatasourceProviderService:
"""
update datasource credentials.
"""
current_user, _ = current_account_with_tenant()
with Session(db.engine) as session:
datasource_provider = (
@ -944,6 +950,7 @@ class DatasourceProviderService:
for key, value in credentials.items()
}
try:
current_user = get_current_user()
self.provider_manager.validate_provider_credentials(
tenant_id=tenant_id,
user_id=current_user.id,

View File

@ -0,0 +1,134 @@
"""
TestContainers-based integration tests for mail_register_task.py
This module provides integration tests for email registration tasks
using TestContainers to ensure real database and service interactions.
"""
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from libs.email_i18n import EmailType
from tasks.mail_register_task import send_email_register_mail_task, send_email_register_mail_task_when_account_exist
class TestMailRegisterTask:
"""Integration tests for mail_register_task using testcontainers."""
@pytest.fixture
def mock_mail_dependencies(self):
"""Mock setup for mail service dependencies."""
with (
patch("tasks.mail_register_task.mail") as mock_mail,
patch("tasks.mail_register_task.get_email_i18n_service") as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email i18n service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
yield {
"mail": mock_mail,
"email_service": mock_email_service,
"get_email_service": mock_get_email_service,
}
def test_send_email_register_mail_task_success(self, db_session_with_containers, mock_mail_dependencies):
"""Test successful email registration mail sending."""
fake = Faker()
language = "en-US"
to_email = fake.email()
code = fake.numerify("######")
send_email_register_mail_task(language=language, to=to_email, code=code)
mock_mail_dependencies["mail"].is_inited.assert_called_once()
mock_mail_dependencies["email_service"].send_email.assert_called_once_with(
email_type=EmailType.EMAIL_REGISTER,
language_code=language,
to=to_email,
template_context={
"to": to_email,
"code": code,
},
)
def test_send_email_register_mail_task_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""Test email registration task when mail service is not initialized."""
mock_mail_dependencies["mail"].is_inited.return_value = False
send_email_register_mail_task(language="en-US", to="test@example.com", code="123456")
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_email_register_mail_task_exception_handling(self, db_session_with_containers, mock_mail_dependencies):
"""Test email registration task exception handling."""
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
fake = Faker()
to_email = fake.email()
code = fake.numerify("######")
with patch("tasks.mail_register_task.logger") as mock_logger:
send_email_register_mail_task(language="en-US", to=to_email, code=code)
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
def test_send_email_register_mail_task_when_account_exist_success(
self, db_session_with_containers, mock_mail_dependencies
):
"""Test successful email registration mail sending when account exists."""
fake = Faker()
language = "en-US"
to_email = fake.email()
account_name = fake.name()
with patch("tasks.mail_register_task.dify_config") as mock_config:
mock_config.CONSOLE_WEB_URL = "https://console.dify.ai"
send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name)
mock_mail_dependencies["email_service"].send_email.assert_called_once_with(
email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST,
language_code=language,
to=to_email,
template_context={
"to": to_email,
"login_url": "https://console.dify.ai/signin",
"reset_password_url": "https://console.dify.ai/reset-password",
"account_name": account_name,
},
)
def test_send_email_register_mail_task_when_account_exist_mail_not_initialized(
self, db_session_with_containers, mock_mail_dependencies
):
"""Test account exist email task when mail service is not initialized."""
mock_mail_dependencies["mail"].is_inited.return_value = False
send_email_register_mail_task_when_account_exist(
language="en-US", to="test@example.com", account_name="Test User"
)
mock_mail_dependencies["get_email_service"].assert_not_called()
mock_mail_dependencies["email_service"].send_email.assert_not_called()
def test_send_email_register_mail_task_when_account_exist_exception_handling(
self, db_session_with_containers, mock_mail_dependencies
):
"""Test account exist email task exception handling."""
mock_mail_dependencies["email_service"].send_email.side_effect = Exception("Email service error")
fake = Faker()
to_email = fake.email()
account_name = fake.name()
with patch("tasks.mail_register_task.logger") as mock_logger:
send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name)
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)

View File

@ -34,12 +34,17 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None)
monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None)
# Mock user resolution to avoid database access
from unittest.mock import Mock
mock_user = Mock()
monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user)
# replace `WorkflowAppGenerator.generate` 's return value.
monkeypatch.setattr(
"core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate",
lambda *args, **kwargs: {"data": {"error": "oops"}},
)
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
with pytest.raises(ToolInvokeError) as exc_info:
# WorkflowTool always returns a generator, so we need to iterate to

View File

@ -7,14 +7,11 @@ This test suite validates the behavior of a workflow that:
3. Handles multiple answer nodes with different outputs
"""
import pytest
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
)
from .test_mock_config import MockConfigBuilder
@ -29,7 +26,6 @@ class TestComplexBranchWorkflow:
self.runner = TableTestRunner()
self.fixture_path = "test_complex_branch"
@pytest.mark.skip(reason="output in this workflow can be random")
def test_hello_branch_with_llm(self):
"""
Test when query contains 'hello' - should trigger true branch.
@ -41,42 +37,17 @@ class TestComplexBranchWorkflow:
fixture_path=self.fixture_path,
query="hello world",
expected_outputs={
"answer": f"{mock_text_1}contains 'hello'",
"answer": f"contains 'hello'{mock_text_1}",
},
description="Basic hello case with parallel LLM execution",
use_auto_mock=True,
mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()),
expected_event_sequence=[
GraphRunStartedEvent,
# Start
NodeRunStartedEvent,
NodeRunSucceededEvent,
# If/Else (no streaming)
NodeRunStartedEvent,
NodeRunSucceededEvent,
# LLM (with streaming)
NodeRunStartedEvent,
]
# LLM
+ [NodeRunStreamChunkEvent] * (mock_text_1.count(" ") + 2)
+ [
# Answer's text
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
# Answer
NodeRunStartedEvent,
NodeRunSucceededEvent,
# Answer 2
NodeRunStartedEvent,
NodeRunSucceededEvent,
GraphRunSucceededEvent,
],
),
WorkflowTestCase(
fixture_path=self.fixture_path,
query="say hello to everyone",
expected_outputs={
"answer": "Mocked response for greetingcontains 'hello'",
"answer": "contains 'hello'Mocked response for greeting",
},
description="Hello in middle of sentence",
use_auto_mock=True,
@ -93,6 +64,35 @@ class TestComplexBranchWorkflow:
for result in suite_result.results:
assert result.success, f"Test '{result.test_case.description}' failed: {result.error}"
assert result.actual_outputs
assert any(isinstance(event, GraphRunStartedEvent) for event in result.events)
assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events)
start_index = next(
idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunStartedEvent)
)
success_index = max(
idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent)
)
assert start_index < success_index
started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)}
assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), (
f"Branch or LLM nodes missing in events: {started_node_ids}"
)
assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), (
"Expected streaming chunks from LLM execution"
)
llm_start_index = next(
idx
for idx, event in enumerate(result.events)
if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322"
)
assert any(
idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent)
for idx, event in enumerate(result.events)
), "Streaming chunks should follow LLM node start"
def test_non_hello_branch_with_llm(self):
"""

View File

@ -95,10 +95,10 @@ def _make_succeeded_event() -> NodeRunSucceededEvent:
)
def test_dispatcher_checks_commands_after_node_completion() -> None:
"""Dispatcher should only check commands after node completion events."""
def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None:
"""Dispatcher polls commands when idle and re-checks after completion events."""
started_checks = _run_dispatcher_for_event(_make_started_event())
succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event())
assert started_checks == 0
assert succeeded_checks == 1
assert started_checks == 1
assert succeeded_checks == 2

View File

@ -21,7 +21,6 @@ from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom,
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
from core.workflow.entities import GraphInitParams, GraphRuntimeState, VariablePool
from core.workflow.graph import Graph
from core.workflow.nodes.llm import llm_utils
from core.workflow.nodes.llm.entities import (
ContextConfig,
@ -83,14 +82,6 @@ def graph_init_params() -> GraphInitParams:
)
@pytest.fixture
def graph() -> Graph:
# TODO: This fixture uses old Graph constructor parameters that are incompatible
# with the new queue-based engine. Need to rewrite for new engine architecture.
pytest.skip("Graph fixture incompatible with new queue-based engine - needs rewrite for ResponseStreamCoordinator")
return Graph()
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
@ -105,7 +96,7 @@ def graph_runtime_state() -> GraphRuntimeState:
@pytest.fixture
def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
@ -493,9 +484,7 @@ def test_handle_list_messages_basic(llm_node):
@pytest.fixture
def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
@ -655,7 +644,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

View File

@ -181,14 +181,11 @@ class TestAuthIntegration:
)
def test_all_providers_factory_creation(self, provider, credentials):
"""Test factory creation for all supported providers"""
try:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class is not None
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class is not None
factory = ApiKeyAuthFactory(provider, credentials)
assert factory.auth is not None
except ImportError:
pytest.skip(f"Provider {provider} not implemented yet")
factory = ApiKeyAuthFactory(provider, credentials)
assert factory.auth is not None
def _create_success_response(self, status_code=200):
"""Create successful HTTP response mock"""

View File

@ -41,7 +41,10 @@ class TestMetadataBugCompleteValidation:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# Should crash with TypeError
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
@ -51,7 +54,10 @@ class TestMetadataBugCompleteValidation:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)

View File

@ -29,7 +29,10 @@ class TestMetadataNullableBug:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)
@ -40,7 +43,10 @@ class TestMetadataNullableBug:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# This should crash with TypeError when calling len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.update_metadata_name("dataset-123", "metadata-456", None)
@ -88,7 +94,10 @@ class TestMetadataNullableBug:
mock_user.current_tenant_id = "tenant-123"
mock_user.id = "user-456"
with patch("services.metadata_service.current_user", mock_user):
with patch(
"services.metadata_service.current_account_with_tenant",
return_value=(mock_user, mock_user.current_tenant_id),
):
# Step 4: Service layer crashes on len(None)
with pytest.raises(TypeError, match="object of type 'NoneType' has no len"):
MetadataService.create_metadata("dataset-123", mock_metadata_args)

View File

@ -1409,6 +1409,7 @@ dev = [
{ name = "pytest-cov" },
{ name = "pytest-env" },
{ name = "pytest-mock" },
{ name = "pytest-timeout" },
{ name = "ruff" },
{ name = "scipy-stubs" },
{ name = "sseclient-py" },
@ -1600,6 +1601,7 @@ dev = [
{ name = "pytest-cov", specifier = "~=4.1.0" },
{ name = "pytest-env", specifier = "~=1.1.3" },
{ name = "pytest-mock", specifier = "~=3.14.0" },
{ name = "pytest-timeout", specifier = ">=2.4.0" },
{ name = "ruff", specifier = "~=0.14.0" },
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
{ name = "sseclient-py", specifier = ">=1.8.0" },
@ -1684,7 +1686,7 @@ vdb = [
{ name = "tidb-vector", specifier = "==0.0.9" },
{ name = "upstash-vector", specifier = "==0.6.0" },
{ name = "volcengine-compat", specifier = "~=1.0.0" },
{ name = "weaviate-client", specifier = ">=4.0.0,<5.0.0" },
{ name = "weaviate-client", specifier = "==4.17.0" },
{ name = "xinference-client", specifier = "~=1.2.2" },
]
@ -4996,6 +4998,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" },
]
[[package]]
name = "pytest-timeout"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ac/82/4c9ecabab13363e72d880f2fb504c5f750433b2b6f16e99f4ec21ada284c/pytest_timeout-2.4.0.tar.gz", hash = "sha256:7e68e90b01f9eff71332b25001f85c75495fc4e3a836701876183c4bcfd0540a", size = 17973, upload-time = "2025-05-05T19:44:34.99Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/b6/3127540ecdf1464a00e5a01ee60a1b09175f6913f0644ac748494d9c4b21/pytest_timeout-2.4.0-py3-none-any.whl", hash = "sha256:c42667e5cdadb151aeb5b26d114aff6bdf5a907f176a007a30b940d3d865b5c2", size = 14382, upload-time = "2025-05-05T19:44:33.502Z" },
]
[[package]]
name = "python-calamine"
version = "0.5.3"

View File

@ -4,4 +4,6 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/artifact_tests/
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}"
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/artifact_tests/

View File

@ -4,7 +4,9 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/integration_tests/model_runtime/anthropic \
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}"
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/model_runtime/anthropic \
api/tests/integration_tests/model_runtime/azure_openai \
api/tests/integration_tests/model_runtime/openai api/tests/integration_tests/model_runtime/chatglm \
api/tests/integration_tests/model_runtime/google api/tests/integration_tests/model_runtime/xinference \

View File

@ -4,4 +4,6 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/test_containers_integration_tests
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}"
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/test_containers_integration_tests

View File

@ -4,4 +4,6 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/integration_tests/tools
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}"
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/tools

View File

@ -4,5 +4,7 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-20}"
# libs
pytest api/tests/unit_tests
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests

View File

@ -4,7 +4,9 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/integration_tests/vdb/chroma \
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-180}"
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/vdb/chroma \
api/tests/integration_tests/vdb/milvus \
api/tests/integration_tests/vdb/pgvecto_rs \
api/tests/integration_tests/vdb/pgvector \

View File

@ -4,4 +4,6 @@ set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/../.."
pytest api/tests/integration_tests/workflow
PYTEST_TIMEOUT="${PYTEST_TIMEOUT:-120}"
pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/workflow

View File

@ -24,6 +24,13 @@ services:
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
# TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release
entrypoint:
- /bin/bash
- -c
- |
uv pip install --system weaviate-client==4.17.0
exec /bin/bash /app/api/docker/entrypoint.sh
networks:
- ssrf_proxy_network
- default
@ -51,6 +58,13 @@ services:
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
# TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release
entrypoint:
- /bin/bash
- -c
- |
uv pip install --system weaviate-client==4.17.0
exec /bin/bash /app/api/docker/entrypoint.sh
networks:
- ssrf_proxy_network
- default
@ -331,7 +345,6 @@ services:
weaviate:
image: semitechnologies/weaviate:1.27.0
profiles:
- ""
- weaviate
restart: always
volumes:

View File

@ -1,9 +0,0 @@
services:
api:
volumes:
- ../api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:/app/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py:ro
command: >
sh -c "
pip install --no-cache-dir 'weaviate>=4.0.0' &&
/bin/bash /entrypoint.sh
"

View File

@ -636,6 +636,13 @@ services:
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
# TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release
entrypoint:
- /bin/bash
- -c
- |
uv pip install --system weaviate-client==4.17.0
exec /bin/bash /app/api/docker/entrypoint.sh
networks:
- ssrf_proxy_network
- default
@ -663,6 +670,13 @@ services:
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
# TODO: Remove this entrypoint override when weaviate-client 4.17.0 is included in the next Dify release
entrypoint:
- /bin/bash
- -c
- |
uv pip install --system weaviate-client==4.17.0
exec /bin/bash /app/api/docker/entrypoint.sh
networks:
- ssrf_proxy_network
- default
@ -943,7 +957,6 @@ services:
weaviate:
image: semitechnologies/weaviate:1.27.0
profiles:
- ""
- weaviate
restart: always
volumes:

View File

@ -0,0 +1,187 @@
# Weaviate Migration Guide: v1.19 → v1.27
## Overview
Dify has upgraded from Weaviate v1.19 to v1.27 with the Python client updated from v3.24 to v4.17.
## What Changed
### Breaking Changes
1. **Weaviate Server**: `1.19.0``1.27.0`
1. **Python Client**: `weaviate-client~=3.24.0``weaviate-client==4.17.0`
1. **gRPC Required**: Weaviate v1.27 requires gRPC port `50051` (in addition to HTTP port `8080`)
1. **Docker Compose**: Added temporary entrypoint overrides for client installation
### Key Improvements
- Faster vector operations via gRPC
- Improved batch processing
- Better error handling
## Migration Steps
### For Docker Users
#### Step 1: Backup Your Data
```bash
cd docker
docker compose down
sudo cp -r ./volumes/weaviate ./volumes/weaviate_backup_$(date +%Y%m%d)
```
#### Step 2: Update Dify
```bash
git pull origin main
docker compose pull
```
#### Step 3: Start Services
```bash
docker compose up -d
sleep 30
curl http://localhost:8080/v1/meta
```
#### Step 4: Verify Migration
```bash
# Check both ports are accessible
curl http://localhost:8080/v1/meta
netstat -tulpn | grep 50051
# Test in Dify UI:
# 1. Go to Knowledge Base
# 2. Test search functionality
# 3. Upload a test document
```
### For Source Installation
#### Step 1: Update Dependencies
```bash
cd api
uv sync --dev
uv run python -c "import weaviate; print(weaviate.__version__)"
# Should show: 4.17.0
```
#### Step 2: Update Weaviate Server
```bash
cd docker
docker compose -f docker-compose.middleware.yaml --profile weaviate up -d weaviate
curl http://localhost:8080/v1/meta
netstat -tulpn | grep 50051
```
## Troubleshooting
### Error: "No module named 'weaviate.classes'"
**Solution**:
```bash
cd api
uv sync --reinstall-package weaviate-client
uv run python -c "import weaviate; print(weaviate.__version__)"
# Should show: 4.17.0
```
### Error: "gRPC health check failed"
**Solution**:
```bash
# Check Weaviate ports
docker ps | grep weaviate
# Should show: 0.0.0.0:8080->8080/tcp, 0.0.0.0:50051->50051/tcp
# If missing gRPC port, add to docker-compose:
# ports:
# - "8080:8080"
# - "50051:50051"
```
### Error: "Weaviate version 1.19.0 is not supported"
**Solution**:
```bash
# Update Weaviate image in docker-compose
# Change: semitechnologies/weaviate:1.19.0
# To: semitechnologies/weaviate:1.27.0
docker compose down
docker compose up -d
```
### Data Migration Failed
**Solution**:
```bash
cd docker
docker compose down
sudo rm -rf ./volumes/weaviate
sudo cp -r ./volumes/weaviate_backup_YYYYMMDD ./volumes/weaviate
docker compose up -d
```
## Rollback Instructions
```bash
# 1. Stop services
docker compose down
# 2. Restore data backup
sudo rm -rf ./volumes/weaviate
sudo cp -r ./volumes/weaviate_backup_YYYYMMDD ./volumes/weaviate
# 3. Checkout previous version
git checkout <previous-commit>
# 4. Restart services
docker compose up -d
```
## Compatibility
| Component | Old Version | New Version | Compatible |
|-----------|-------------|-------------|------------|
| Weaviate Server | 1.19.0 | 1.27.0 | ✅ Yes |
| weaviate-client | ~3.24.0 | ==4.17.0 | ✅ Yes |
| Existing Data | v1.19 format | v1.27 format | ✅ Yes |
## Testing Checklist
Before deploying to production:
- [ ] Backup all Weaviate data
- [ ] Test in staging environment
- [ ] Verify existing collections are accessible
- [ ] Test vector search functionality
- [ ] Test document upload and retrieval
- [ ] Monitor gRPC connection stability
- [ ] Check performance metrics
## Support
If you encounter issues:
1. Check GitHub Issues: https://github.com/langgenius/dify/issues
1. Create a bug report with:
- Error messages
- Docker logs: `docker compose logs weaviate`
- Dify version
- Migration steps attempted
## Important Notes
- **Data Safety**: Existing vector data remains fully compatible
- **No Re-indexing**: No need to rebuild vector indexes
- **Temporary Workaround**: The entrypoint overrides are temporary until next Dify release
- **Performance**: May see improved performance due to gRPC usage

View File

@ -100,7 +100,10 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
})
}
}
catch (e: any) {
if (e.code === 'authentication_failed')
Toast.notify({ type: 'error', message: e.message })
}
finally {
setIsLoading(false)
}

View File

@ -32,7 +32,7 @@ const TopKItem: FC<Props> = ({
}) => {
const { t } = useTranslation()
const handleParamChange = (key: string, value: number) => {
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
let notOutRangeValue = Number.parseInt(value.toFixed(0))
notOutRangeValue = Math.max(VALUE_LIMIT.min, notOutRangeValue)
notOutRangeValue = Math.min(VALUE_LIMIT.max, notOutRangeValue)
onChange(key, notOutRangeValue)

View File

@ -25,8 +25,8 @@ export type TextareaProps = {
destructive?: boolean
styleCss?: CSSProperties
ref?: React.Ref<HTMLTextAreaElement>
onFocus?: () => void
onBlur?: () => void
onFocus?: React.FocusEventHandler<HTMLTextAreaElement>
onBlur?: React.FocusEventHandler<HTMLTextAreaElement>
} & React.TextareaHTMLAttributes<HTMLTextAreaElement> & VariantProps<typeof textareaVariants>
const Textarea = React.forwardRef<HTMLTextAreaElement, TextareaProps>(

View File

@ -234,6 +234,9 @@ const ConditionItem = ({
draft.varType = resolvedVarType
draft.value = resolvedVarType === VarType.boolean ? false : ''
draft.comparison_operator = getOperators(resolvedVarType)[0]
delete draft.key
delete draft.sub_variable_condition
delete draft.numberVarType
setTimeout(() => setControlPromptEditorRerenderKey(Date.now()))
})
doUpdateCondition(newCondition)

View File

@ -1,8 +1,8 @@
import { memo } from 'react'
import { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import Tooltip from '@/app/components/base/tooltip'
import Input from '@/app/components/base/input'
import Switch from '@/app/components/base/switch'
import { InputNumber } from '@/app/components/base/input-number'
export type TopKAndScoreThresholdProps = {
topK: number
@ -14,6 +14,24 @@ export type TopKAndScoreThresholdProps = {
readonly?: boolean
hiddenScoreThreshold?: boolean
}
const maxTopK = (() => {
const configValue = Number.parseInt(globalThis.document?.body?.getAttribute('data-public-top-k-max-value') || '', 10)
if (configValue && !isNaN(configValue))
return configValue
return 10
})()
const TOP_K_VALUE_LIMIT = {
amount: 1,
min: 1,
max: maxTopK,
}
const SCORE_THRESHOLD_VALUE_LIMIT = {
step: 0.01,
min: 0,
max: 1,
}
const TopKAndScoreThreshold = ({
topK,
onTopKChange,
@ -25,18 +43,18 @@ const TopKAndScoreThreshold = ({
hiddenScoreThreshold,
}: TopKAndScoreThresholdProps) => {
const { t } = useTranslation()
const handleTopKChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const value = Number(e.target.value)
if (Number.isNaN(value))
return
onTopKChange?.(value)
}
const handleTopKChange = useCallback((value: number) => {
let notOutRangeValue = Number.parseInt(value.toFixed(0))
notOutRangeValue = Math.max(TOP_K_VALUE_LIMIT.min, notOutRangeValue)
notOutRangeValue = Math.min(TOP_K_VALUE_LIMIT.max, notOutRangeValue)
onTopKChange?.(notOutRangeValue)
}, [onTopKChange])
const handleScoreThresholdChange = (e: React.ChangeEvent<HTMLInputElement>) => {
const value = Number(e.target.value)
if (Number.isNaN(value))
return
onScoreThresholdChange?.(value)
const handleScoreThresholdChange = (value: number) => {
let notOutRangeValue = Number.parseFloat(value.toFixed(2))
notOutRangeValue = Math.max(SCORE_THRESHOLD_VALUE_LIMIT.min, notOutRangeValue)
notOutRangeValue = Math.min(SCORE_THRESHOLD_VALUE_LIMIT.max, notOutRangeValue)
onScoreThresholdChange?.(notOutRangeValue)
}
return (
@ -49,11 +67,13 @@ const TopKAndScoreThreshold = ({
popupContent={t('appDebug.datasetConfig.top_kTip')}
/>
</div>
<Input
<InputNumber
disabled={readonly}
type='number'
{...TOP_K_VALUE_LIMIT}
size='regular'
value={topK}
onChange={handleTopKChange}
disabled={readonly}
/>
</div>
{
@ -74,11 +94,13 @@ const TopKAndScoreThreshold = ({
popupContent={t('appDebug.datasetConfig.score_thresholdTip')}
/>
</div>
<Input
<InputNumber
disabled={readonly || !isScoreThresholdEnabled}
type='number'
{...SCORE_THRESHOLD_VALUE_LIMIT}
size='regular'
value={scoreThreshold}
onChange={handleScoreThresholdChange}
disabled={readonly || !isScoreThresholdEnabled}
/>
</div>
)

View File

@ -18,7 +18,7 @@ type ConditionNumberProps = {
nodesOutputVars: NodeOutPutVar[]
availableNodes: Node[]
isCommonVariable?: boolean
commonVariables: { name: string, type: string }[]
commonVariables: { name: string; type: string; value: string }[]
} & ConditionValueMethodProps
const ConditionNumber = ({
value,

View File

@ -18,7 +18,7 @@ type ConditionStringProps = {
nodesOutputVars: NodeOutPutVar[]
availableNodes: Node[]
isCommonVariable?: boolean
commonVariables: { name: string, type: string }[]
commonVariables: { name: string; type: string; value: string }[]
} & ConditionValueMethodProps
const ConditionString = ({
value,

View File

@ -128,6 +128,6 @@ export type MetadataShape = {
availableNumberVars?: NodeOutPutVar[]
availableNumberNodesWithParent?: Node[]
isCommonVariable?: boolean
availableCommonStringVars?: { name: string; type: string; }[]
availableCommonNumberVars?: { name: string; type: string; }[]
availableCommonStringVars?: { name: string; type: string; value: string }[]
availableCommonNumberVars?: { name: string; type: string; value: string }[]
}

View File

@ -24,7 +24,7 @@ const JsonImporter: FC<JsonImporterProps> = ({
const [open, setOpen] = useState(false)
const [json, setJson] = useState('')
const [parseError, setParseError] = useState<any>(null)
const importBtnRef = useRef<HTMLButtonElement>(null)
const importBtnRef = useRef<HTMLElement>(null)
const advancedEditing = useVisualEditorStore(state => state.advancedEditing)
const isAddingNewField = useVisualEditorStore(state => state.isAddingNewField)
const { emit } = useMittContext()

View File

@ -18,7 +18,7 @@ type VisualEditorProviderProps = {
export const VisualEditorContext = createContext<VisualEditorContextType>(null)
export const VisualEditorContextProvider = ({ children }: VisualEditorProviderProps) => {
const storeRef = useRef<VisualEditorStore>()
const storeRef = useRef<VisualEditorStore | null>(null)
if (!storeRef.current)
storeRef.current = createVisualEditorStore()

View File

@ -23,7 +23,7 @@ const useConfig = (id: string, payload: LLMNodeType) => {
const { nodesReadOnly: readOnly } = useNodesReadOnly()
const isChatMode = useIsChatMode()
const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type]
const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' })
const { inputs, setInputs: doSetInputs } = useNodeCrud<LLMNodeType>(id, payload)
const inputRef = useRef(inputs)

View File

@ -10,7 +10,7 @@ export const checkNodeValid = (_payload: LLMNodeType) => {
export const getFieldType = (field: Field) => {
const { type, items } = field
if(field.schemaType === 'file') return 'file'
if(field.schemaType === 'file') return Type.file
if (type !== Type.array || !items)
return type

View File

@ -196,6 +196,9 @@ const ConditionItem = ({
draft.varType = varItem.type
draft.value = ''
draft.comparison_operator = getOperators(varItem.type)[0]
delete draft.key
delete draft.sub_variable_condition
delete draft.numberVarType
})
doUpdateCondition(newCondition)
setOpen(false)

View File

@ -9,7 +9,10 @@ import BlockSelector from '../../../../block-selector'
import type { Param, ParamType } from '../../types'
import cn from '@/utils/classnames'
import { useStore } from '@/app/components/workflow/store'
import type { PluginDefaultValue } from '@/app/components/workflow/block-selector/types'
import type {
PluginDefaultValue,
ToolDefaultValue,
} from '@/app/components/workflow/block-selector/types'
import type { ToolParameter } from '@/app/components/tools/types'
import { CollectionType } from '@/app/components/tools/types'
import type { BlockEnum } from '@/app/components/workflow/types'
@ -44,9 +47,10 @@ const ImportFromTool: FC<Props> = ({
const workflowTools = useStore(s => s.workflowTools)
const handleSelectTool = useCallback((_type: BlockEnum, toolInfo?: PluginDefaultValue) => {
if (!toolInfo || !('tool_name' in toolInfo))
if (!toolInfo || 'datasource_name' in toolInfo || !('tool_name' in toolInfo))
return
const { provider_id, provider_type, tool_name: tool_name } = toolInfo!
const { provider_id, provider_type, tool_name } = toolInfo as ToolDefaultValue
const currentTools = (() => {
switch (provider_type) {
case CollectionType.builtIn:

View File

@ -27,7 +27,7 @@ const useConfig = (id: string, payload: ParameterExtractorNodeType) => {
const { handleOutVarRenameChange } = useWorkflow()
const isChatMode = useIsChatMode()
const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type]
const [defaultRolePrefix, setDefaultRolePrefix] = useState<{ user: string; assistant: string }>({ user: '', assistant: '' })
const { inputs, setInputs: doSetInputs } = useNodeCrud<ParameterExtractorNodeType>(id, payload)

View File

@ -20,7 +20,7 @@ const useConfig = (id: string, payload: QuestionClassifierNodeType) => {
const updateNodeInternals = useUpdateNodeInternals()
const { nodesReadOnly: readOnly } = useNodesReadOnly()
const isChatMode = useIsChatMode()
const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type]
const { getBeforeNodesInSameBranch } = useWorkflow()
const startNode = getBeforeNodesInSameBranch(id).find(node => node.data.type === BlockEnum.Start)
const startNodeId = startNode?.id

View File

@ -13,7 +13,7 @@ import useAvailableVarList from '@/app/components/workflow/nodes/_base/hooks/use
const useConfig = (id: string, payload: TemplateTransformNodeType) => {
const { nodesReadOnly: readOnly } = useNodesReadOnly()
const defaultConfig = useStore(s => s.nodesDefaultConfigs)[payload.type]
const defaultConfig = useStore(s => s.nodesDefaultConfigs)?.[payload.type]
const { inputs, setInputs: doSetInputs } = useNodeCrud<TemplateTransformNodeType>(id, payload)
const inputsRef = useRef(inputs)

View File

@ -9,7 +9,7 @@ import Button from '@/app/components/base/button'
import type { AgentLogItemWithChildren } from '@/types/workflow'
type AgentLogNavMoreProps = {
options: { id: string; label: string }[]
options: AgentLogItemWithChildren[]
onShowAgentOrToolLog: (detail?: AgentLogItemWithChildren) => void
}
const AgentLogNavMore = ({
@ -41,10 +41,10 @@ const AgentLogNavMore = ({
{
options.map(option => (
<div
key={option.id}
key={option.message_id}
className='system-md-regular flex h-8 cursor-pointer items-center rounded-lg px-2 text-text-secondary hover:bg-state-base-hover'
onClick={() => {
onShowAgentOrToolLog(option as AgentLogItemWithChildren)
onShowAgentOrToolLog(option)
setOpen(false)
}}
>

View File

@ -23,8 +23,10 @@ import {
} from '../node-handle'
import ErrorHandleOnNode from '../error-handle-on-node'
type NodeChildElement = ReactElement<Partial<NodeProps>>
type NodeCardProps = NodeProps & {
children?: ReactElement
children?: NodeChildElement
}
const BaseCard = ({

View File

@ -242,7 +242,7 @@ const DebugConfigurationContext = createContext<IDebugConfiguration>({
},
datasetConfigsRef: {
current: null,
},
} as unknown as RefObject<DatasetConfigs>,
setDatasetConfigs: noop,
hasSetContextVar: false,
isShowVisionConfig: false,

View File

@ -1,5 +1,6 @@
'use client'
import { useEffect } from 'react'
import { validateRedirectUrl } from '@/utils/urlValidation'
export const useOAuthCallback = () => {
useEffect(() => {
@ -40,6 +41,7 @@ export const openOAuthPopup = (url: string, callback: (data?: any) => void) => {
const left = window.screenX + (window.outerWidth - width) / 2
const top = window.screenY + (window.outerHeight - height) / 2
validateRedirectUrl(url)
const popup = window.open(
url,
'OAuth',

View File

@ -145,7 +145,7 @@
"@babel/core": "^7.28.3",
"@chromatic-com/storybook": "^3.1.0",
"@eslint-react/eslint-plugin": "^1.15.0",
"@happy-dom/jest-environment": "^20.0.0",
"@happy-dom/jest-environment": "^20.0.2",
"@mdx-js/loader": "^3.1.0",
"@mdx-js/react": "^3.1.0",
"@next/bundle-analyzer": "15.5.4",

View File

@ -348,8 +348,8 @@ importers:
specifier: ^1.15.0
version: 1.52.3(eslint@9.35.0(jiti@2.6.1))(ts-api-utils@2.1.0(typescript@5.8.3))(typescript@5.8.3)
'@happy-dom/jest-environment':
specifier: ^20.0.0
version: 20.0.0(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)
specifier: ^20.0.2
version: 20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)
'@mdx-js/loader':
specifier: ^3.1.0
version: 3.1.0(acorn@8.15.0)(webpack@5.100.2(esbuild@0.25.0)(uglify-js@3.19.3))
@ -1647,8 +1647,8 @@ packages:
'@formatjs/intl-localematcher@0.5.10':
resolution: {integrity: sha512-af3qATX+m4Rnd9+wHcjJ4w2ijq+rAVP3CCinJQvFv1kgSu1W6jypUmvleJxcewdxmutM8dmIRZFxO/IQBZmP2Q==}
'@happy-dom/jest-environment@20.0.0':
resolution: {integrity: sha512-dUyMDNJzPDFopSDyzKdbeYs8z9B4jLj9kXnru8TjYdGeLsQKf+6r0lq/9T2XVcu04QFxXMykt64A+KjsaJTaNA==}
'@happy-dom/jest-environment@20.0.4':
resolution: {integrity: sha512-75OcYtjO+jqxWiYiXvbwR8JZITX1/8iAjRSRpZ/rNjO6UnYebwX6HdI91Ix09xYZEO1X/xOof6HX1EiZnrgnXA==}
engines: {node: '>=20.0.0'}
peerDependencies:
'@jest/environment': '>=25.0.0'
@ -5582,8 +5582,8 @@ packages:
hachure-fill@0.5.2:
resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==}
happy-dom@20.0.0:
resolution: {integrity: sha512-GkWnwIFxVGCf2raNrxImLo397RdGhLapj5cT3R2PT7FwL62Ze1DROhzmYW7+J3p9105DYMVenEejEbnq5wA37w==}
happy-dom@20.0.4:
resolution: {integrity: sha512-WxFtvnij6G64/MtMimnZhF0nKx3LUQKc20zjATD6tKiqOykUwQkd+2FW/DZBAFNjk4oWh0xdv/HBleGJmSY/Iw==}
engines: {node: '>=20.0.0'}
has-flag@4.0.0:
@ -10143,12 +10143,12 @@ snapshots:
dependencies:
tslib: 2.8.1
'@happy-dom/jest-environment@20.0.0(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)':
'@happy-dom/jest-environment@20.0.4(@jest/environment@29.7.0)(@jest/fake-timers@29.7.0)(@jest/types@29.6.3)(jest-mock@29.7.0)(jest-util@29.7.0)':
dependencies:
'@jest/environment': 29.7.0
'@jest/fake-timers': 29.7.0
'@jest/types': 29.6.3
happy-dom: 20.0.0
happy-dom: 20.0.4
jest-mock: 29.7.0
jest-util: 29.7.0
@ -14802,7 +14802,7 @@ snapshots:
hachure-fill@0.5.2: {}
happy-dom@20.0.0:
happy-dom@20.0.4:
dependencies:
'@types/node': 20.19.20
'@types/whatwg-mimetype': 3.0.2

View File

@ -324,7 +324,7 @@ const baseFetch = base
type UploadOptions = {
xhr: XMLHttpRequest
method: string
method?: string
url?: string
headers?: Record<string, string>
data: FormData

View File

@ -0,0 +1,23 @@
/**
* Validates that a URL is safe for redirection.
* Only allows HTTP and HTTPS protocols to prevent XSS attacks.
*
* @param url - The URL string to validate
* @throws Error if the URL has an unsafe protocol
*/
export function validateRedirectUrl(url: string): void {
try {
const parsedUrl = new URL(url)
if (parsedUrl.protocol !== 'http:' && parsedUrl.protocol !== 'https:')
throw new Error('Authorization URL must be HTTP or HTTPS')
}
catch (error) {
if (
error instanceof Error
&& error.message === 'Authorization URL must be HTTP or HTTPS'
)
throw error
// If URL parsing fails, it's also invalid
throw new Error(`Invalid URL: ${url}`)
}
}