mirror of https://github.com/langgenius/dify.git
Merge branch 'main' into 12-24-json-for-translation
This commit is contained in:
commit
d5418babdf
|
|
@ -1,8 +1,9 @@
|
|||
import base64
|
||||
from typing import Literal
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
|
@ -15,22 +16,8 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
|||
|
||||
|
||||
class SubscriptionQuery(BaseModel):
|
||||
plan: str = Field(..., description="Subscription plan")
|
||||
interval: str = Field(..., description="Billing interval")
|
||||
|
||||
@field_validator("plan")
|
||||
@classmethod
|
||||
def validate_plan(cls, value: str) -> str:
|
||||
if value not in [CloudPlan.PROFESSIONAL, CloudPlan.TEAM]:
|
||||
raise ValueError("Invalid plan")
|
||||
return value
|
||||
|
||||
@field_validator("interval")
|
||||
@classmethod
|
||||
def validate_interval(cls, value: str) -> str:
|
||||
if value not in {"month", "year"}:
|
||||
raise ValueError("Invalid interval")
|
||||
return value
|
||||
plan: Literal[CloudPlan.PROFESSIONAL, CloudPlan.TEAM] = Field(..., description="Subscription plan")
|
||||
interval: Literal["month", "year"] = Field(..., description="Billing interval")
|
||||
|
||||
|
||||
class PartnerTenantsPayload(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import logging
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import marshal_with
|
||||
|
|
@ -26,6 +25,7 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni
|
|||
from core.model_runtime.errors.invoke import InvokeError
|
||||
from fields.message_fields import message_infinite_scroll_pagination_fields
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
|
@ -44,8 +44,8 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class MessageListQuery(BaseModel):
|
||||
conversation_id: UUID
|
||||
first_id: UUID | None = None
|
||||
conversation_id: UUIDStrOrEmpty
|
||||
first_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
|
@ -10,19 +8,19 @@ from controllers.console import console_ns
|
|||
from controllers.console.explore.error import NotCompletionAppError
|
||||
from controllers.console.explore.wraps import InstalledAppResource
|
||||
from fields.conversation_fields import message_file_fields
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import TimestampField, UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.errors.message import MessageNotExistsError
|
||||
from services.saved_message_service import SavedMessageService
|
||||
|
||||
|
||||
class SavedMessageListQuery(BaseModel):
|
||||
last_id: UUID | None = None
|
||||
last_id: UUIDStrOrEmpty | None = None
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
|
||||
|
||||
class SavedMessageCreatePayload(BaseModel):
|
||||
message_id: UUID
|
||||
message_id: UUIDStrOrEmpty
|
||||
|
||||
|
||||
register_schema_models(console_ns, SavedMessageListQuery, SavedMessageCreatePayload)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from flask_restx import Resource, reqparse
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
|
@ -10,10 +12,20 @@ from models import TenantAccountRole
|
|||
from services.model_load_balancing_service import ModelLoadBalancingService
|
||||
|
||||
|
||||
class LoadBalancingCredentialPayload(BaseModel):
|
||||
model: str
|
||||
model_type: ModelType
|
||||
credentials: dict[str, object]
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoadBalancingCredentialPayload)
|
||||
|
||||
|
||||
@console_ns.route(
|
||||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/credentials-validate"
|
||||
)
|
||||
class LoadBalancingCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -24,20 +36,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# validate model load balancing credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
|
@ -49,9 +48,9 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
result = False
|
||||
|
|
@ -69,6 +68,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
|
|||
"/workspaces/current/model-providers/<path:provider>/models/load-balancing-configs/<string:config_id>/credentials-validate"
|
||||
)
|
||||
class LoadBalancingConfigCredentialsValidateApi(Resource):
|
||||
@console_ns.expect(console_ns.models[LoadBalancingCredentialPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
|
|
@ -79,20 +79,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||
|
||||
tenant_id = current_tenant_id
|
||||
|
||||
parser = (
|
||||
reqparse.RequestParser()
|
||||
.add_argument("model", type=str, required=True, nullable=False, location="json")
|
||||
.add_argument(
|
||||
"model_type",
|
||||
type=str,
|
||||
required=True,
|
||||
nullable=False,
|
||||
choices=[mt.value for mt in ModelType],
|
||||
location="json",
|
||||
)
|
||||
.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||
)
|
||||
args = parser.parse_args()
|
||||
payload = LoadBalancingCredentialPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
# validate model load balancing config credentials
|
||||
model_load_balancing_service = ModelLoadBalancingService()
|
||||
|
|
@ -104,9 +91,9 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
|
|||
model_load_balancing_service.validate_load_balancing_credentials(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider,
|
||||
model=args["model"],
|
||||
model_type=args["model_type"],
|
||||
credentials=args["credentials"],
|
||||
model=payload.model,
|
||||
model_type=payload.model_type,
|
||||
credentials=payload.credentials,
|
||||
config_id=config_id,
|
||||
)
|
||||
except CredentialsValidateFailedError as ex:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import io
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from flask import make_response, redirect, request, send_file
|
||||
|
|
@ -17,6 +18,7 @@ from controllers.console.wraps import (
|
|||
is_admin_or_owner_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
|
||||
from core.helper.tool_provider_cache import ToolProviderListCache
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
|
|
@ -40,6 +42,8 @@ from services.tools.tools_manage_service import ToolCommonService
|
|||
from services.tools.tools_transform_service import ToolTransformService
|
||||
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def is_valid_url(url: str) -> bool:
|
||||
if not url:
|
||||
|
|
@ -945,8 +949,8 @@ class ToolProviderMCPApi(Resource):
|
|||
configuration = MCPConfiguration.model_validate(args["configuration"])
|
||||
authentication = MCPAuthentication.model_validate(args["authentication"]) if args["authentication"] else None
|
||||
|
||||
# Create provider in transaction
|
||||
with Session(db.engine) as session, session.begin():
|
||||
# 1) Create provider in a short transaction (no network I/O inside)
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
result = service.create_provider(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -962,7 +966,28 @@ class ToolProviderMCPApi(Resource):
|
|||
authentication=authentication,
|
||||
)
|
||||
|
||||
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
|
||||
# 2) Try to fetch tools immediately after creation so they appear without a second save.
|
||||
# Perform network I/O outside any DB session to avoid holding locks.
|
||||
try:
|
||||
reconnect = MCPToolManageService.reconnect_with_url(
|
||||
server_url=args["server_url"],
|
||||
headers=args.get("headers") or {},
|
||||
timeout=configuration.timeout,
|
||||
sse_read_timeout=configuration.sse_read_timeout,
|
||||
)
|
||||
# Update just-created provider with authed/tools in a new short transaction
|
||||
with session_factory.create_session() as session, session.begin():
|
||||
service = MCPToolManageService(session=session)
|
||||
db_provider = service.get_provider(provider_id=result.id, tenant_id=tenant_id)
|
||||
db_provider.authed = reconnect.authed
|
||||
db_provider.tools = reconnect.tools
|
||||
|
||||
result = ToolTransformService.mcp_provider_to_user_provider(db_provider, for_list=True)
|
||||
except Exception:
|
||||
# Best-effort: if initial fetch fails (e.g., auth required), return created provider as-is
|
||||
logger.warning("Failed to fetch MCP tools after creation", exc_info=True)
|
||||
|
||||
# Final cache invalidation to ensure list views are up to date
|
||||
ToolProviderListCache.invalidate_cache(tenant_id)
|
||||
|
||||
return jsonable_encoder(result)
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from controllers.service_api.dataset.error import DatasetInUseError, DatasetName
|
|||
from controllers.service_api.wraps import (
|
||||
DatasetApiResource,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
validate_dataset_token,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
|
|
@ -460,9 +459,8 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
def get(self, _, dataset_id):
|
||||
def get(self, _):
|
||||
"""Get all knowledge type tags."""
|
||||
assert isinstance(current_user, Account)
|
||||
cid = current_user.current_tenant_id
|
||||
|
|
@ -482,8 +480,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
"""Add a knowledge type tag."""
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
|
|
@ -506,8 +503,7 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
}
|
||||
)
|
||||
@service_api_ns.marshal_with(build_dataset_tag_fields(service_api_ns))
|
||||
@validate_dataset_token
|
||||
def patch(self, _, dataset_id):
|
||||
def patch(self, _):
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
raise Forbidden()
|
||||
|
|
@ -533,9 +529,8 @@ class DatasetTagsApi(DatasetApiResource):
|
|||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
@edit_permission_required
|
||||
def delete(self, _, dataset_id):
|
||||
def delete(self, _):
|
||||
"""Delete a knowledge type tag."""
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
|
@ -555,8 +550,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
|||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
|
|
@ -580,8 +574,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
|||
403: "Forbidden - insufficient permissions",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def post(self, _, dataset_id):
|
||||
def post(self, _):
|
||||
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
|
||||
assert isinstance(current_user, Account)
|
||||
if not (current_user.has_edit_permission or current_user.is_dataset_editor):
|
||||
|
|
@ -604,7 +597,6 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
|
|||
401: "Unauthorized - invalid API token",
|
||||
}
|
||||
)
|
||||
@validate_dataset_token
|
||||
def get(self, _, *args, **kwargs):
|
||||
"""Get all knowledge type tags."""
|
||||
dataset_id = kwargs.get("dataset_id")
|
||||
|
|
|
|||
|
|
@ -319,8 +319,14 @@ class MCPToolManageService:
|
|||
except MCPError as e:
|
||||
raise ValueError(f"Failed to connect to MCP server: {e}")
|
||||
|
||||
# Update database with retrieved tools
|
||||
db_provider.tools = json.dumps([tool.model_dump() for tool in tools])
|
||||
# Update database with retrieved tools (ensure description is a non-null string)
|
||||
tools_payload = []
|
||||
for tool in tools:
|
||||
data = tool.model_dump()
|
||||
if data.get("description") is None:
|
||||
data["description"] = ""
|
||||
tools_payload.append(data)
|
||||
db_provider.tools = json.dumps(tools_payload)
|
||||
db_provider.authed = True
|
||||
db_provider.updated_at = datetime.now()
|
||||
self._session.flush()
|
||||
|
|
@ -620,6 +626,21 @@ class MCPToolManageService:
|
|||
server_url_hash=new_server_url_hash,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def reconnect_with_url(
|
||||
*,
|
||||
server_url: str,
|
||||
headers: dict[str, str],
|
||||
timeout: float | None,
|
||||
sse_read_timeout: float | None,
|
||||
) -> ReconnectResult:
|
||||
return MCPToolManageService._reconnect_with_url(
|
||||
server_url=server_url,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
sse_read_timeout=sse_read_timeout,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _reconnect_with_url(
|
||||
*,
|
||||
|
|
@ -642,9 +663,16 @@ class MCPToolManageService:
|
|||
sse_read_timeout=sse_read_timeout,
|
||||
) as mcp_client:
|
||||
tools = mcp_client.list_tools()
|
||||
# Ensure tool descriptions are non-null in payload
|
||||
tools_payload = []
|
||||
for t in tools:
|
||||
d = t.model_dump()
|
||||
if d.get("description") is None:
|
||||
d["description"] = ""
|
||||
tools_payload.append(d)
|
||||
return ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps([tool.model_dump() for tool in tools]),
|
||||
tools=json.dumps(tools_payload),
|
||||
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
|
||||
)
|
||||
except MCPAuthError:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,145 @@
|
|||
"""Unit tests for load balancing credential validation APIs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import importlib
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
from models.account import TenantAccountRole
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def load_balancing_module(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Reload controller module with lightweight decorators for testing."""
|
||||
|
||||
from controllers.console import console_ns, wraps
|
||||
from libs import login
|
||||
|
||||
def _noop(func):
|
||||
return func
|
||||
|
||||
monkeypatch.setattr(login, "login_required", _noop)
|
||||
monkeypatch.setattr(wraps, "setup_required", _noop)
|
||||
monkeypatch.setattr(wraps, "account_initialization_required", _noop)
|
||||
|
||||
def _noop_route(*args, **kwargs): # type: ignore[override]
|
||||
def _decorator(cls):
|
||||
return cls
|
||||
|
||||
return _decorator
|
||||
|
||||
monkeypatch.setattr(console_ns, "route", _noop_route)
|
||||
|
||||
module_name = "controllers.console.workspace.load_balancing_config"
|
||||
sys.modules.pop(module_name, None)
|
||||
module = importlib.import_module(module_name)
|
||||
return module
|
||||
|
||||
|
||||
def _mock_user(role: TenantAccountRole) -> SimpleNamespace:
|
||||
return SimpleNamespace(current_role=role)
|
||||
|
||||
|
||||
def _prepare_context(module, monkeypatch: pytest.MonkeyPatch, role=TenantAccountRole.OWNER):
|
||||
user = _mock_user(role)
|
||||
monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "tenant-123"))
|
||||
mock_service = MagicMock()
|
||||
monkeypatch.setattr(module, "ModelLoadBalancingService", lambda: mock_service)
|
||||
return mock_service
|
||||
|
||||
|
||||
def _request_payload():
|
||||
return {"model": "gpt-4o", "model_type": ModelType.LLM, "credentials": {"api_key": "sk-***"}}
|
||||
|
||||
|
||||
def test_validate_credentials_success(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
)
|
||||
|
||||
|
||||
def test_validate_credentials_returns_error_message(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
service.validate_load_balancing_credentials.side_effect = CredentialsValidateFailedError("invalid credentials")
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingCredentialsValidateApi().post(provider="openai")
|
||||
|
||||
assert response == {"result": "error", "error": "invalid credentials"}
|
||||
|
||||
|
||||
def test_validate_credentials_requires_privileged_role(
|
||||
app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
_prepare_context(load_balancing_module, monkeypatch, role=TenantAccountRole.NORMAL)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
api = load_balancing_module.LoadBalancingCredentialsValidateApi()
|
||||
with pytest.raises(Forbidden):
|
||||
api.post(provider="openai")
|
||||
|
||||
|
||||
def test_validate_credentials_with_config_id(app: Flask, load_balancing_module, monkeypatch: pytest.MonkeyPatch):
|
||||
service = _prepare_context(load_balancing_module, monkeypatch)
|
||||
|
||||
with app.test_request_context(
|
||||
"/workspaces/current/model-providers/openai/models/load-balancing-configs/cfg-1/credentials-validate",
|
||||
method="POST",
|
||||
json=_request_payload(),
|
||||
):
|
||||
response = load_balancing_module.LoadBalancingConfigCredentialsValidateApi().post(
|
||||
provider="openai", config_id="cfg-1"
|
||||
)
|
||||
|
||||
assert response == {"result": "success"}
|
||||
service.validate_load_balancing_credentials.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
provider="openai",
|
||||
model="gpt-4o",
|
||||
model_type=ModelType.LLM,
|
||||
credentials={"api_key": "sk-***"},
|
||||
config_id="cfg-1",
|
||||
)
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_restx import Api
|
||||
|
||||
from controllers.console.workspace.tool_providers import ToolProviderMCPApi
|
||||
from core.db.session_factory import configure_session_factory
|
||||
from extensions.ext_database import db
|
||||
from services.tools.mcp_tools_manage_service import ReconnectResult
|
||||
|
||||
|
||||
# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file.
|
||||
# They are intentionally no-ops because the test already patches the required
|
||||
# behaviors explicitly via @patch and context managers below.
|
||||
@pytest.fixture
|
||||
def _mock_cache():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _mock_user_tenant():
|
||||
return
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
|
||||
api = Api(app)
|
||||
api.add_resource(ToolProviderMCPApi, "/console/api/workspaces/current/tool-provider/mcp")
|
||||
db.init_app(app)
|
||||
# Configure session factory used by controller code
|
||||
with app.app_context():
|
||||
configure_session_factory(db.engine)
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||
)
|
||||
@patch("controllers.console.workspace.tool_providers.ToolProviderListCache.invalidate_cache", return_value=None)
|
||||
@patch("controllers.console.workspace.tool_providers.Session")
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(
|
||||
mock_reconnect, mock_session, mock_invalidate_cache, mock_current_account_with_tenant, client
|
||||
):
|
||||
# Arrange: reconnect returns tools immediately
|
||||
mock_reconnect.return_value = ReconnectResult(
|
||||
authed=True,
|
||||
tools=json.dumps(
|
||||
[{"name": "ping", "description": "ok", "inputSchema": {"type": "object"}, "outputSchema": {}}]
|
||||
),
|
||||
encrypted_credentials="{}",
|
||||
)
|
||||
|
||||
# Fake service.create_provider -> returns object with id for reload
|
||||
svc = MagicMock()
|
||||
create_result = MagicMock()
|
||||
create_result.id = "provider-1"
|
||||
svc.create_provider.return_value = create_result
|
||||
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
# Patch MCPToolManageService constructed inside controller
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||
payload = {
|
||||
"server_url": "http://example.com/mcp",
|
||||
"name": "demo",
|
||||
"icon": "😀",
|
||||
"icon_type": "emoji",
|
||||
"icon_background": "#000",
|
||||
"server_identifier": "demo-sid",
|
||||
"configuration": {"timeout": 5, "sse_read_timeout": 30},
|
||||
"headers": {},
|
||||
"authentication": {},
|
||||
}
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
),
|
||||
):
|
||||
resp = client.post(
|
||||
"/console/api/workspaces/current/tool-provider/mcp",
|
||||
data=json.dumps(payload),
|
||||
content_type="application/json",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert resp.status_code == 200
|
||||
body = resp.get_json()
|
||||
assert body.get("id") == "provider-1"
|
||||
# 若 transform 后包含 tools 字段,确保非空
|
||||
assert isinstance(body.get("tools"), list)
|
||||
assert body["tools"]
|
||||
|
|
@ -399,6 +399,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=*
|
|||
COOKIE_DOMAIN=
|
||||
# When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1.
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN=
|
||||
NEXT_PUBLIC_BATCH_CONCURRENCY=5
|
||||
|
||||
# ------------------------------
|
||||
# File Storage Configuration
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ x-shared-env: &shared-api-worker-env
|
|||
CONSOLE_CORS_ALLOW_ORIGINS: ${CONSOLE_CORS_ALLOW_ORIGINS:-*}
|
||||
COOKIE_DOMAIN: ${COOKIE_DOMAIN:-}
|
||||
NEXT_PUBLIC_COOKIE_DOMAIN: ${NEXT_PUBLIC_COOKIE_DOMAIN:-}
|
||||
NEXT_PUBLIC_BATCH_CONCURRENCY: ${NEXT_PUBLIC_BATCH_CONCURRENCY:-5}
|
||||
STORAGE_TYPE: ${STORAGE_TYPE:-opendal}
|
||||
OPENDAL_SCHEME: ${OPENDAL_SCHEME:-fs}
|
||||
OPENDAL_FS_ROOT: ${OPENDAL_FS_ROOT:-storage}
|
||||
|
|
|
|||
|
|
@ -73,3 +73,6 @@ NEXT_PUBLIC_MAX_TREE_DEPTH=50
|
|||
|
||||
# The API key of amplitude
|
||||
NEXT_PUBLIC_AMPLITUDE_API_KEY=
|
||||
|
||||
# number of concurrency
|
||||
NEXT_PUBLIC_BATCH_CONCURRENCY=5
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ const DatasetConfig: FC = () => {
|
|||
}))
|
||||
}, [setDatasetConfigs, datasetConfigsRef])
|
||||
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
|
||||
let operator: ComparisonOperator = ComparisonOperator.is
|
||||
|
||||
if (type === MetadataFilteringVariableType.number)
|
||||
|
|
@ -184,6 +184,7 @@ const DatasetConfig: FC = () => {
|
|||
|
||||
const newCondition = {
|
||||
id: uuid4(),
|
||||
metadata_id: id, // Save metadata.id for reliable reference
|
||||
name,
|
||||
comparison_operator: operator,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -679,7 +679,7 @@ const Configuration: FC = () => {
|
|||
const toolInCollectionList = collectionList.find(c => tool.provider_id === c.id)
|
||||
return {
|
||||
...tool,
|
||||
isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.id === tool.id && deletedTool.tool_name === tool.tool_name) ?? false,
|
||||
isDeleted: res.deleted_tools?.some((deletedTool: any) => deletedTool.provider_id === tool.provider_id && deletedTool.tool_name === tool.tool_name) ?? false,
|
||||
notAuthor: toolInCollectionList?.is_team_authorization === false,
|
||||
...(tool.provider_type === 'builtin'
|
||||
? {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,136 @@
|
|||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import Apps from './index'
|
||||
|
||||
const mockUseExploreAppList = vi.fn()
|
||||
|
||||
vi.mock('ahooks', () => ({
|
||||
useDebounceFn: (fn: () => void) => ({
|
||||
run: () => setTimeout(fn, 0),
|
||||
cancel: vi.fn(),
|
||||
flush: () => fn(),
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({ isCurrentWorkspaceEditor: true }),
|
||||
}))
|
||||
vi.mock('use-context-selector', async () => {
|
||||
const actual = await vi.importActual<typeof import('use-context-selector')>('use-context-selector')
|
||||
return {
|
||||
...actual,
|
||||
useContext: () => ({ hasEditPermission: true }),
|
||||
}
|
||||
})
|
||||
vi.mock('@/hooks/use-tab-searchparams', () => ({
|
||||
useTabSearchParams: () => ['Recommended', vi.fn()],
|
||||
}))
|
||||
vi.mock('@/service/use-explore', () => ({
|
||||
useExploreAppList: () => mockUseExploreAppList(),
|
||||
}))
|
||||
vi.mock('@/app/components/app/type-selector', () => ({
|
||||
__esModule: true,
|
||||
default: ({ value, onChange }: { value: AppModeEnum[], onChange: (value: AppModeEnum[]) => void }) => (
|
||||
<button data-testid="type-selector" onClick={() => onChange([...value, 'chat' as AppModeEnum])}>{value.join(',')}</button>
|
||||
),
|
||||
}))
|
||||
vi.mock('../app-card', () => ({
|
||||
__esModule: true,
|
||||
default: ({ app, onCreate }: { app: any, onCreate: () => void }) => (
|
||||
<div
|
||||
data-testid="app-card"
|
||||
data-name={app.app.name}
|
||||
onClick={onCreate}
|
||||
>
|
||||
{app.app.name}
|
||||
</div>
|
||||
),
|
||||
}))
|
||||
vi.mock('@/app/components/explore/create-app-modal', () => ({
|
||||
__esModule: true,
|
||||
default: () => <div data-testid="create-from-template-modal" />,
|
||||
}))
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: { notify: vi.fn() },
|
||||
}))
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
}))
|
||||
vi.mock('@/service/apps', () => ({
|
||||
importDSL: vi.fn().mockResolvedValue({ app_id: '1' }),
|
||||
}))
|
||||
vi.mock('@/service/explore', () => ({
|
||||
fetchAppDetail: vi.fn().mockResolvedValue({
|
||||
export_data: 'dsl',
|
||||
mode: 'chat',
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/app/components/workflow/plugin-dependency/hooks', () => ({
|
||||
usePluginDependencies: () => ({
|
||||
handleCheckPluginDependencies: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
vi.mock('@/utils/app-redirection', () => ({
|
||||
getRedirection: vi.fn(),
|
||||
}))
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push: vi.fn() }),
|
||||
}))
|
||||
|
||||
const createAppEntry = (name: string, category: string) => ({
|
||||
app_id: name,
|
||||
category,
|
||||
app: {
|
||||
id: name,
|
||||
name,
|
||||
icon_type: 'emoji',
|
||||
icon: '🙂',
|
||||
icon_background: '#000',
|
||||
icon_url: null,
|
||||
description: 'desc',
|
||||
mode: AppModeEnum.CHAT,
|
||||
},
|
||||
})
|
||||
|
||||
describe('Apps', () => {
|
||||
const defaultData = {
|
||||
allList: [
|
||||
createAppEntry('Alpha', 'Cat A'),
|
||||
createAppEntry('Bravo', 'Cat B'),
|
||||
],
|
||||
categories: ['Cat A', 'Cat B'],
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseExploreAppList.mockReturnValue({
|
||||
data: defaultData,
|
||||
isLoading: false,
|
||||
})
|
||||
})
|
||||
|
||||
it('renders template cards when data is available', () => {
|
||||
render(<Apps />)
|
||||
|
||||
expect(screen.getAllByTestId('app-card')).toHaveLength(2)
|
||||
expect(screen.getByText('Alpha')).toBeInTheDocument()
|
||||
expect(screen.getByText('Bravo')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('opens create modal when a template card is clicked', () => {
|
||||
render(<Apps />)
|
||||
|
||||
fireEvent.click(screen.getAllByTestId('app-card')[0])
|
||||
expect(screen.getByTestId('create-from-template-modal')).toBeInTheDocument()
|
||||
})
|
||||
it('shows no template message when list is empty', () => {
|
||||
mockUseExploreAppList.mockReturnValueOnce({
|
||||
data: { allList: [], categories: [] },
|
||||
isLoading: false,
|
||||
})
|
||||
|
||||
render(<Apps />)
|
||||
|
||||
expect(screen.getByText('app.newApp.noTemplateFound')).toBeInTheDocument()
|
||||
expect(screen.getByText('app.newApp.noTemplateFoundTip')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import Sidebar, { AppCategories } from './sidebar'
|
||||
|
||||
vi.mock('@remixicon/react', () => ({
|
||||
RiStickyNoteAddLine: () => <span>sticky</span>,
|
||||
RiThumbUpLine: () => <span>thumb</span>,
|
||||
}))
|
||||
describe('Sidebar', () => {
|
||||
it('renders recommended and custom categories', () => {
|
||||
render(<Sidebar current={AppCategories.RECOMMENDED} categories={['Cat A', 'Cat B']} />)
|
||||
|
||||
expect(screen.getByText('app.newAppFromTemplate.sidebar.Recommended')).toBeInTheDocument()
|
||||
expect(screen.getByText('Cat A')).toBeInTheDocument()
|
||||
expect(screen.getByText('Cat B')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('notifies callbacks when items are clicked', () => {
|
||||
const onClick = vi.fn()
|
||||
const onCreate = vi.fn()
|
||||
render(
|
||||
<Sidebar
|
||||
current="Cat A"
|
||||
categories={['Cat A']}
|
||||
onClick={onClick}
|
||||
onCreateFromBlank={onCreate}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('app.newAppFromTemplate.sidebar.Recommended'))
|
||||
expect(onClick).toHaveBeenCalledWith(AppCategories.RECOMMENDED)
|
||||
|
||||
fireEvent.click(screen.getByText('Cat A'))
|
||||
expect(onClick).toHaveBeenCalledWith('Cat A')
|
||||
|
||||
fireEvent.click(screen.getByText('app.newApp.startFromBlank'))
|
||||
expect(onCreate).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
import type { ReactNode } from 'react'
|
||||
import type { ModalContextState } from '@/context/modal-context'
|
||||
import type { ProviderContextState } from '@/context/provider-context'
|
||||
import type { AppDetailResponse } from '@/models/app'
|
||||
import type { AppSSO } from '@/types/app'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { baseProviderContextValue } from '@/context/provider-context'
|
||||
import { AppModeEnum } from '@/types/app'
|
||||
import SettingsModal from './index'
|
||||
|
||||
vi.mock('react-i18next', async () => {
|
||||
const actual = await vi.importActual<typeof import('react-i18next')>('react-i18next')
|
||||
return {
|
||||
...actual,
|
||||
useTranslation: () => ({
|
||||
t: (key: string, options?: Record<string, unknown>) => {
|
||||
if (options?.returnObjects)
|
||||
return [`${key}-feature-1`, `${key}-feature-2`]
|
||||
if (options)
|
||||
return `${key}:${JSON.stringify(options)}`
|
||||
return key
|
||||
},
|
||||
i18n: {
|
||||
language: 'en',
|
||||
changeLanguage: vi.fn(),
|
||||
},
|
||||
}),
|
||||
Trans: ({ children }: { children?: ReactNode }) => <>{children}</>,
|
||||
}
|
||||
})
|
||||
|
||||
const mockNotify = vi.fn()
|
||||
const mockOnClose = vi.fn()
|
||||
const mockOnSave = vi.fn()
|
||||
const mockSetShowPricingModal = vi.fn()
|
||||
const mockSetShowAccountSettingModal = vi.fn()
|
||||
const mockUseProviderContext = vi.fn<() => ProviderContextState>()
|
||||
|
||||
const buildModalContext = (): ModalContextState => ({
|
||||
setShowAccountSettingModal: mockSetShowAccountSettingModal,
|
||||
setShowApiBasedExtensionModal: vi.fn(),
|
||||
setShowModerationSettingModal: vi.fn(),
|
||||
setShowExternalDataToolModal: vi.fn(),
|
||||
setShowPricingModal: mockSetShowPricingModal,
|
||||
setShowAnnotationFullModal: vi.fn(),
|
||||
setShowModelModal: vi.fn(),
|
||||
setShowExternalKnowledgeAPIModal: vi.fn(),
|
||||
setShowModelLoadBalancingModal: vi.fn(),
|
||||
setShowOpeningModal: vi.fn(),
|
||||
setShowUpdatePluginModal: vi.fn(),
|
||||
setShowEducationExpireNoticeModal: vi.fn(),
|
||||
setShowTriggerEventsLimitModal: vi.fn(),
|
||||
})
|
||||
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: () => buildModalContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/toast', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/app/components/base/toast')>('@/app/components/base/toast')
|
||||
return {
|
||||
...actual,
|
||||
useToastContext: () => ({
|
||||
notify: mockNotify,
|
||||
close: vi.fn(),
|
||||
}),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/context/i18n')>('@/context/i18n')
|
||||
return {
|
||||
...actual,
|
||||
useDocLink: () => (path?: string) => `https://docs.example.com${path ?? ''}`,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/provider-context', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/context/provider-context')>('@/context/provider-context')
|
||||
return {
|
||||
...actual,
|
||||
useProviderContext: () => mockUseProviderContext(),
|
||||
}
|
||||
})
|
||||
|
||||
const mockAppInfo = {
|
||||
site: {
|
||||
title: 'Test App',
|
||||
icon_type: 'emoji',
|
||||
icon: '😀',
|
||||
icon_background: '#ABCDEF',
|
||||
icon_url: 'https://example.com/icon.png',
|
||||
description: 'A description',
|
||||
chat_color_theme: '#123456',
|
||||
chat_color_theme_inverted: true,
|
||||
copyright: '© Dify',
|
||||
privacy_policy: '',
|
||||
custom_disclaimer: 'Disclaimer',
|
||||
default_language: 'en-US',
|
||||
show_workflow_steps: true,
|
||||
use_icon_as_answer_icon: true,
|
||||
},
|
||||
mode: AppModeEnum.ADVANCED_CHAT,
|
||||
enable_sso: false,
|
||||
} as unknown as AppDetailResponse & Partial<AppSSO>
|
||||
|
||||
const renderSettingsModal = () => render(
|
||||
<SettingsModal
|
||||
isChat
|
||||
isShow
|
||||
appInfo={mockAppInfo}
|
||||
onClose={mockOnClose}
|
||||
onSave={mockOnSave}
|
||||
/>,
|
||||
)
|
||||
|
||||
describe('SettingsModal', () => {
|
||||
beforeEach(() => {
|
||||
mockNotify.mockClear()
|
||||
mockOnClose.mockClear()
|
||||
mockOnSave.mockClear()
|
||||
mockSetShowPricingModal.mockClear()
|
||||
mockSetShowAccountSettingModal.mockClear()
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
...baseProviderContextValue,
|
||||
enableBilling: true,
|
||||
plan: {
|
||||
...baseProviderContextValue.plan,
|
||||
type: Plan.sandbox,
|
||||
},
|
||||
webappCopyrightEnabled: true,
|
||||
})
|
||||
})
|
||||
|
||||
it('should render the modal and expose the expanded settings section', async () => {
|
||||
renderSettingsModal()
|
||||
expect(screen.getByText('appOverview.overview.appInfo.settings.title')).toBeInTheDocument()
|
||||
|
||||
const showMoreEntry = screen.getByText('appOverview.overview.appInfo.settings.more.entry')
|
||||
fireEvent.click(showMoreEntry)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.copyRightPlaceholder')).toBeInTheDocument()
|
||||
expect(screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should notify the user when the name is empty', async () => {
|
||||
renderSettingsModal()
|
||||
const nameInput = screen.getByPlaceholderText('app.appNamePlaceholder')
|
||||
fireEvent.change(nameInput, { target: { value: '' } })
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({ message: 'app.newApp.nameNotEmpty' }))
|
||||
})
|
||||
expect(mockOnSave).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate the theme color and show an error when the hex is invalid', async () => {
|
||||
renderSettingsModal()
|
||||
const colorInput = screen.getByPlaceholderText('E.g #A020F0')
|
||||
fireEvent.change(colorInput, { target: { value: 'not-a-hex' } })
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
message: 'appOverview.overview.appInfo.settings.invalidHexMessage',
|
||||
}))
|
||||
})
|
||||
expect(mockOnSave).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should validate the privacy policy URL when advanced settings are open', async () => {
|
||||
renderSettingsModal()
|
||||
fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry'))
|
||||
const privacyInput = screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')
|
||||
// eslint-disable-next-line sonarjs/no-clear-text-protocols
|
||||
fireEvent.change(privacyInput, { target: { value: 'ftp://invalid-url' } })
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
await waitFor(() => {
|
||||
expect(mockNotify).toHaveBeenCalledWith(expect.objectContaining({
|
||||
message: 'appOverview.overview.appInfo.settings.invalidPrivacyPolicy',
|
||||
}))
|
||||
})
|
||||
expect(mockOnSave).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should save valid settings and close the modal', async () => {
|
||||
mockOnSave.mockResolvedValueOnce(undefined)
|
||||
renderSettingsModal()
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
await waitFor(() => expect(mockOnSave).toHaveBeenCalled())
|
||||
expect(mockOnSave).toHaveBeenCalledWith(expect.objectContaining({
|
||||
title: mockAppInfo.site.title,
|
||||
description: mockAppInfo.site.description,
|
||||
default_language: mockAppInfo.site.default_language,
|
||||
chat_color_theme: mockAppInfo.site.chat_color_theme,
|
||||
chat_color_theme_inverted: mockAppInfo.site.chat_color_theme_inverted,
|
||||
prompt_public: false,
|
||||
copyright: mockAppInfo.site.copyright,
|
||||
privacy_policy: mockAppInfo.site.privacy_policy,
|
||||
custom_disclaimer: mockAppInfo.site.custom_disclaimer,
|
||||
icon_type: 'emoji',
|
||||
icon: mockAppInfo.site.icon,
|
||||
icon_background: mockAppInfo.site.icon_background,
|
||||
show_workflow_steps: mockAppInfo.site.show_workflow_steps,
|
||||
use_icon_as_answer_icon: mockAppInfo.site.use_icon_as_answer_icon,
|
||||
enable_sso: mockAppInfo.enable_sso,
|
||||
}))
|
||||
expect(mockOnClose).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -26,7 +26,7 @@ import DifyLogo from '@/app/components/base/logo/dify-logo'
|
|||
import Toast from '@/app/components/base/toast'
|
||||
import Res from '@/app/components/share/text-generation/result'
|
||||
import RunOnce from '@/app/components/share/text-generation/run-once'
|
||||
import { appDefaultIconBackground, DEFAULT_VALUE_MAX_LEN } from '@/config'
|
||||
import { appDefaultIconBackground, BATCH_CONCURRENCY, DEFAULT_VALUE_MAX_LEN } from '@/config'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useWebAppStore } from '@/context/web-app-context'
|
||||
import { useAppFavicon } from '@/hooks/use-app-favicon'
|
||||
|
|
@ -43,7 +43,7 @@ import MenuDropdown from './menu-dropdown'
|
|||
import RunBatch from './run-batch'
|
||||
import ResDownload from './run-batch/res-download'
|
||||
|
||||
const GROUP_SIZE = 5 // to avoid RPM(Request per minute) limit. The group task finished then the next group.
|
||||
const GROUP_SIZE = BATCH_CONCURRENCY // to avoid RPM(Request per minute) limit. The group task finished then the next group.
|
||||
enum TaskStatus {
|
||||
pending = 'pending',
|
||||
running = 'running',
|
||||
|
|
|
|||
|
|
@ -0,0 +1,60 @@
|
|||
import type { Credential } from '@/app/components/tools/types'
|
||||
import { act, fireEvent, render, screen } from '@testing-library/react'
|
||||
import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types'
|
||||
import ConfigCredential from './config-credentials'
|
||||
|
||||
describe('ConfigCredential', () => {
|
||||
const baseCredential: Credential = {
|
||||
auth_type: AuthType.none,
|
||||
}
|
||||
const mockOnChange = vi.fn()
|
||||
const mockOnHide = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('renders and calls onHide when cancel is pressed', async () => {
|
||||
await act(async () => {
|
||||
render(
|
||||
<ConfigCredential
|
||||
credential={baseCredential}
|
||||
onChange={mockOnChange}
|
||||
onHide={mockOnHide}
|
||||
/>,
|
||||
)
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.cancel'))
|
||||
|
||||
expect(mockOnHide).toHaveBeenCalledTimes(1)
|
||||
expect(mockOnChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('allows selecting apiKeyHeader and submits the new credential', async () => {
|
||||
await act(async () => {
|
||||
render(
|
||||
<ConfigCredential
|
||||
credential={baseCredential}
|
||||
onChange={mockOnChange}
|
||||
onHide={mockOnHide}
|
||||
/>,
|
||||
)
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('tools.createTool.authMethod.types.api_key_header'))
|
||||
const headerInput = screen.getByPlaceholderText('tools.createTool.authMethod.types.apiKeyPlaceholder')
|
||||
const valueInput = screen.getByPlaceholderText('tools.createTool.authMethod.types.apiValuePlaceholder')
|
||||
fireEvent.change(headerInput, { target: { value: 'X-Auth' } })
|
||||
fireEvent.change(valueInput, { target: { value: 'sEcReT' } })
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledWith({
|
||||
auth_type: AuthType.apiKeyHeader,
|
||||
api_key_header: 'X-Auth',
|
||||
api_key_header_prefix: AuthHeaderPrefix.custom,
|
||||
api_key_value: 'sEcReT',
|
||||
})
|
||||
expect(mockOnHide).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { importSchemaFromURL } from '@/service/tools'
|
||||
import Toast from '../../base/toast'
|
||||
import examples from './examples'
|
||||
import GetSchema from './get-schema'
|
||||
|
||||
vi.mock('@/service/tools', () => ({
|
||||
importSchemaFromURL: vi.fn(),
|
||||
}))
|
||||
const importSchemaFromURLMock = vi.mocked(importSchemaFromURL)
|
||||
|
||||
describe('GetSchema', () => {
|
||||
const notifySpy = vi.spyOn(Toast, 'notify')
|
||||
const mockOnChange = vi.fn()
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
notifySpy.mockClear()
|
||||
importSchemaFromURLMock.mockReset()
|
||||
render(<GetSchema onChange={mockOnChange} />)
|
||||
})
|
||||
|
||||
it('shows an error when the URL is not http', () => {
|
||||
fireEvent.click(screen.getByText('tools.createTool.importFromUrl'))
|
||||
const input = screen.getByPlaceholderText('tools.createTool.importFromUrlPlaceHolder')
|
||||
// eslint-disable-next-line sonarjs/no-clear-text-protocols
|
||||
fireEvent.change(input, { target: { value: 'ftp://invalid' } })
|
||||
fireEvent.click(screen.getByText('common.operation.ok'))
|
||||
|
||||
expect(notifySpy).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
message: 'tools.createTool.urlError',
|
||||
})
|
||||
})
|
||||
|
||||
it('imports schema from url when valid', async () => {
|
||||
fireEvent.click(screen.getByText('tools.createTool.importFromUrl'))
|
||||
const input = screen.getByPlaceholderText('tools.createTool.importFromUrlPlaceHolder')
|
||||
fireEvent.change(input, { target: { value: 'https://example.com' } })
|
||||
importSchemaFromURLMock.mockResolvedValueOnce({ schema: 'result-schema' })
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.ok'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOnChange).toHaveBeenCalledWith('result-schema')
|
||||
})
|
||||
})
|
||||
|
||||
it('selects example schema when example option clicked', () => {
|
||||
fireEvent.click(screen.getByText('tools.createTool.examples'))
|
||||
fireEvent.click(screen.getByText(`tools.createTool.exampleOptions.${examples[0].key}`))
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalledWith(examples[0].content)
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,154 @@
|
|||
import type { ModalContextState } from '@/context/modal-context'
|
||||
import type { ProviderContextState } from '@/context/provider-context'
|
||||
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { Plan } from '@/app/components/billing/type'
|
||||
import { parseParamsSchema } from '@/service/tools'
|
||||
import EditCustomCollectionModal from './index'
|
||||
|
||||
vi.mock('ahooks', async () => {
|
||||
const actual = await vi.importActual<typeof import('ahooks')>('ahooks')
|
||||
return {
|
||||
...actual,
|
||||
useDebounce: (value: unknown) => value,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/service/tools', () => ({
|
||||
parseParamsSchema: vi.fn(),
|
||||
}))
|
||||
const parseParamsSchemaMock = vi.mocked(parseParamsSchema)
|
||||
|
||||
const mockSetShowPricingModal = vi.fn()
|
||||
const mockSetShowAccountSettingModal = vi.fn()
|
||||
vi.mock('@/context/modal-context', () => ({
|
||||
useModalContext: (): ModalContextState => ({
|
||||
setShowAccountSettingModal: mockSetShowAccountSettingModal,
|
||||
setShowApiBasedExtensionModal: vi.fn(),
|
||||
setShowModerationSettingModal: vi.fn(),
|
||||
setShowExternalDataToolModal: vi.fn(),
|
||||
setShowPricingModal: mockSetShowPricingModal,
|
||||
setShowAnnotationFullModal: vi.fn(),
|
||||
setShowModelModal: vi.fn(),
|
||||
setShowExternalKnowledgeAPIModal: vi.fn(),
|
||||
setShowModelLoadBalancingModal: vi.fn(),
|
||||
setShowOpeningModal: vi.fn(),
|
||||
setShowUpdatePluginModal: vi.fn(),
|
||||
setShowEducationExpireNoticeModal: vi.fn(),
|
||||
setShowTriggerEventsLimitModal: vi.fn(),
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockUseProviderContext = vi.fn()
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => mockUseProviderContext(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', async () => {
|
||||
const actual = await vi.importActual<typeof import('@/context/i18n')>('@/context/i18n')
|
||||
return {
|
||||
...actual,
|
||||
useDocLink: () => (path?: string) => `https://docs.example.com${path ?? ''}`,
|
||||
}
|
||||
})
|
||||
|
||||
describe('EditCustomCollectionModal', () => {
|
||||
const mockOnHide = vi.fn()
|
||||
const mockOnAdd = vi.fn()
|
||||
const mockOnEdit = vi.fn()
|
||||
const mockOnRemove = vi.fn()
|
||||
const toastNotifySpy = vi.spyOn(Toast, 'notify')
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
toastNotifySpy.mockClear()
|
||||
parseParamsSchemaMock.mockResolvedValue({
|
||||
parameters_schema: [],
|
||||
schema_type: 'openapi',
|
||||
})
|
||||
mockUseProviderContext.mockReturnValue({
|
||||
plan: {
|
||||
type: Plan.sandbox,
|
||||
},
|
||||
enableBilling: false,
|
||||
webappCopyrightEnabled: true,
|
||||
} as ProviderContextState)
|
||||
})
|
||||
|
||||
const renderModal = () => render(
|
||||
<EditCustomCollectionModal
|
||||
payload={undefined}
|
||||
onHide={mockOnHide}
|
||||
onAdd={mockOnAdd}
|
||||
onEdit={mockOnEdit}
|
||||
onRemove={mockOnRemove}
|
||||
/>,
|
||||
)
|
||||
|
||||
it('shows an error when the provider name is missing', async () => {
|
||||
renderModal()
|
||||
|
||||
const schemaInput = screen.getByPlaceholderText('tools.createTool.schemaPlaceHolder')
|
||||
fireEvent.change(schemaInput, { target: { value: '{}' } })
|
||||
await waitFor(() => {
|
||||
expect(parseParamsSchemaMock).toHaveBeenCalledWith('{}')
|
||||
})
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastNotifySpy).toHaveBeenCalledWith(expect.objectContaining({
|
||||
message: 'common.errorMsg.fieldRequired:{"field":"tools.createTool.name"}',
|
||||
type: 'error',
|
||||
}))
|
||||
})
|
||||
expect(mockOnAdd).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('shows an error when the schema is missing', async () => {
|
||||
renderModal()
|
||||
|
||||
const providerInput = screen.getByPlaceholderText('tools.createTool.toolNamePlaceHolder')
|
||||
fireEvent.change(providerInput, { target: { value: 'provider' } })
|
||||
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(toastNotifySpy).toHaveBeenCalledWith(expect.objectContaining({
|
||||
message: 'common.errorMsg.fieldRequired:{"field":"tools.createTool.schema"}',
|
||||
type: 'error',
|
||||
}))
|
||||
})
|
||||
expect(mockOnAdd).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('saves a valid custom collection', async () => {
|
||||
renderModal()
|
||||
const providerInput = screen.getByPlaceholderText('tools.createTool.toolNamePlaceHolder')
|
||||
fireEvent.change(providerInput, { target: { value: 'provider' } })
|
||||
|
||||
const schemaInput = screen.getByPlaceholderText('tools.createTool.schemaPlaceHolder')
|
||||
fireEvent.change(schemaInput, { target: { value: '{}' } })
|
||||
|
||||
await waitFor(() => {
|
||||
expect(parseParamsSchemaMock).toHaveBeenCalledWith('{}')
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
fireEvent.click(screen.getByText('common.operation.save'))
|
||||
})
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockOnAdd).toHaveBeenCalledWith(expect.objectContaining({
|
||||
provider: 'provider',
|
||||
schema: '{}',
|
||||
schema_type: 'openapi',
|
||||
credentials: {
|
||||
auth_type: 'none',
|
||||
},
|
||||
labels: [],
|
||||
}))
|
||||
expect(toastNotifySpy).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
import type { CustomCollectionBackend, CustomParamSchema } from '@/app/components/tools/types'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { AuthType } from '@/app/components/tools/types'
|
||||
import I18n from '@/context/i18n'
|
||||
import { testAPIAvailable } from '@/service/tools'
|
||||
import TestApi from './test-api'
|
||||
|
||||
vi.mock('@/service/tools', () => ({
|
||||
testAPIAvailable: vi.fn(),
|
||||
}))
|
||||
const testAPIAvailableMock = vi.mocked(testAPIAvailable)
|
||||
|
||||
describe('TestApi', () => {
|
||||
const customCollection: CustomCollectionBackend = {
|
||||
provider: 'custom',
|
||||
credentials: {
|
||||
auth_type: AuthType.none,
|
||||
},
|
||||
schema_type: 'openapi',
|
||||
schema: '{ }',
|
||||
icon: { background: '', content: '' },
|
||||
privacy_policy: '',
|
||||
custom_disclaimer: '',
|
||||
id: 'test-id',
|
||||
labels: [],
|
||||
}
|
||||
const tool: CustomParamSchema = {
|
||||
operation_id: 'testOp',
|
||||
summary: 'summary',
|
||||
method: 'GET',
|
||||
server_url: 'https://api.example.com',
|
||||
parameters: [{
|
||||
name: 'limit',
|
||||
label: {
|
||||
en_US: 'Limit',
|
||||
zh_Hans: '限制',
|
||||
},
|
||||
// eslint-disable-next-line ts/no-explicit-any
|
||||
} as any],
|
||||
}
|
||||
|
||||
const renderTestApi = () => {
|
||||
const providerValue = {
|
||||
locale: 'en-US',
|
||||
i18n: {},
|
||||
setLocaleOnClient: vi.fn(),
|
||||
}
|
||||
return render(
|
||||
<I18n.Provider value={providerValue as any}>
|
||||
<TestApi
|
||||
customCollection={customCollection}
|
||||
tool={tool}
|
||||
onHide={vi.fn()}
|
||||
/>
|
||||
</I18n.Provider>,
|
||||
)
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('renders parameters and runs the API test', async () => {
|
||||
testAPIAvailableMock.mockResolvedValueOnce({ result: 'ok' })
|
||||
renderTestApi()
|
||||
|
||||
const parameterInput = screen.getAllByRole('textbox')[0]
|
||||
fireEvent.change(parameterInput, { target: { value: '5' } })
|
||||
fireEvent.click(screen.getByRole('button', { name: 'tools.test.title' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(testAPIAvailableMock).toHaveBeenCalledWith({
|
||||
provider_name: customCollection.provider,
|
||||
tool_name: tool.operation_id,
|
||||
credentials: {
|
||||
auth_type: AuthType.none,
|
||||
},
|
||||
schema_type: customCollection.schema_type,
|
||||
schema: customCollection.schema,
|
||||
parameters: {
|
||||
limit: '5',
|
||||
},
|
||||
})
|
||||
expect(screen.getByText('ok')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
|
@ -188,8 +188,8 @@ const FeaturesTrigger = () => {
|
|||
{isChatMode && (
|
||||
<Button
|
||||
className={cn(
|
||||
'text-components-button-secondary-text',
|
||||
theme === 'dark' && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'rounded-lg border border-transparent text-components-button-secondary-text',
|
||||
theme === 'dark' && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
onClick={handleShowFeatures}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ const ChatVariableButton = ({ disabled }: { disabled: boolean }) => {
|
|||
return (
|
||||
<Button
|
||||
className={cn(
|
||||
'p-2',
|
||||
theme === 'dark' && showChatVariablePanel && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'rounded-lg border border-transparent p-2',
|
||||
theme === 'dark' && showChatVariablePanel && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
disabled={disabled}
|
||||
onClick={handleClick}
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ const EnvButton = ({ disabled }: { disabled: boolean }) => {
|
|||
return (
|
||||
<Button
|
||||
className={cn(
|
||||
'p-2',
|
||||
theme === 'dark' && showEnvPanel && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'rounded-lg border border-transparent p-2',
|
||||
theme === 'dark' && showEnvPanel && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
variant="ghost"
|
||||
disabled={disabled}
|
||||
|
|
|
|||
|
|
@ -26,8 +26,8 @@ const GlobalVariableButton = ({ disabled }: { disabled: boolean }) => {
|
|||
return (
|
||||
<Button
|
||||
className={cn(
|
||||
'p-2',
|
||||
theme === 'dark' && showGlobalVariablePanel && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'rounded-lg border border-transparent p-2',
|
||||
theme === 'dark' && showGlobalVariablePanel && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
disabled={disabled}
|
||||
onClick={handleClick}
|
||||
|
|
|
|||
|
|
@ -86,7 +86,8 @@ const HeaderInRestoring = ({
|
|||
disabled={!currentVersion || currentVersion.version === WorkflowVersion.Draft}
|
||||
variant="primary"
|
||||
className={cn(
|
||||
theme === 'dark' && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'rounded-lg border border-transparent',
|
||||
theme === 'dark' && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
>
|
||||
{t('workflow.common.restore')}
|
||||
|
|
@ -94,8 +95,8 @@ const HeaderInRestoring = ({
|
|||
<Button
|
||||
onClick={handleCancelRestore}
|
||||
className={cn(
|
||||
'text-components-button-secondary-accent-text',
|
||||
theme === 'dark' && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'rounded-lg border border-transparent text-components-button-secondary-accent-text',
|
||||
theme === 'dark' && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
>
|
||||
<div className="flex items-center gap-x-0.5">
|
||||
|
|
|
|||
|
|
@ -61,8 +61,8 @@ const VersionHistoryButton: FC<VersionHistoryButtonProps> = ({
|
|||
>
|
||||
<Button
|
||||
className={cn(
|
||||
'p-2',
|
||||
theme === 'dark' && 'rounded-lg border border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
'p-2 rounded-lg border border-transparent',
|
||||
theme === 'dark' && 'border-black/5 bg-white/10 backdrop-blur-sm',
|
||||
)}
|
||||
onClick={handleViewVersionHistory}
|
||||
>
|
||||
|
|
|
|||
|
|
@ -62,8 +62,15 @@ const ConditionItem = ({
|
|||
}, [onRemoveCondition, condition.id])
|
||||
|
||||
const currentMetadata = useMemo(() => {
|
||||
// Try to match by metadata_id first (reliable reference)
|
||||
if (condition.metadata_id) {
|
||||
const found = metadataList.find(metadata => metadata.id === condition.metadata_id)
|
||||
if (found)
|
||||
return found
|
||||
}
|
||||
// Fallback to name matching for backward compatibility with old conditions
|
||||
return metadataList.find(metadata => metadata.name === condition.name)
|
||||
}, [metadataList, condition.name])
|
||||
}, [metadataList, condition.metadata_id, condition.name])
|
||||
|
||||
const handleConditionOperatorChange = useCallback((operator: ComparisonOperator) => {
|
||||
onUpdateCondition?.(
|
||||
|
|
|
|||
|
|
@ -27,11 +27,17 @@ const MetadataTrigger = ({
|
|||
useEffect(() => {
|
||||
if (selectedDatasetsLoaded) {
|
||||
conditions.forEach((condition) => {
|
||||
if (!metadataList.find(metadata => metadata.name === condition.name))
|
||||
// First try to match by metadata_id for reliable reference
|
||||
const foundById = condition.metadata_id && metadataList.find(metadata => metadata.id === condition.metadata_id)
|
||||
// Fallback to name matching only for backward compatibility with old conditions
|
||||
const foundByName = !condition.metadata_id && metadataList.find(metadata => metadata.name === condition.name)
|
||||
|
||||
// Only remove condition if both metadata_id and name matching fail
|
||||
if (!foundById && !foundByName)
|
||||
handleRemoveCondition(condition.id)
|
||||
})
|
||||
}
|
||||
}, [metadataList, handleRemoveCondition, selectedDatasetsLoaded])
|
||||
}, [metadataFilteringConditions, metadataList, handleRemoveCondition, selectedDatasetsLoaded])
|
||||
|
||||
return (
|
||||
<PortalToFollowElem
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ export enum MetadataFilteringVariableType {
|
|||
export type MetadataFilteringCondition = {
|
||||
id: string
|
||||
name: string
|
||||
metadata_id?: string
|
||||
comparison_operator: ComparisonOperator
|
||||
value?: string | number
|
||||
}
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
}))
|
||||
}, [setInputs])
|
||||
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ name, type }) => {
|
||||
const handleAddCondition = useCallback<HandleAddCondition>(({ id, name, type }) => {
|
||||
let operator: ComparisonOperator = ComparisonOperator.is
|
||||
|
||||
if (type === MetadataFilteringVariableType.number)
|
||||
|
|
@ -313,6 +313,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
|
||||
const newCondition = {
|
||||
id: uuid4(),
|
||||
metadata_id: id, // Save metadata.id for reliable reference
|
||||
name,
|
||||
comparison_operator: operator,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ const LocaleLayout = async ({
|
|||
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL,
|
||||
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID,
|
||||
[DatasetAttr.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN]: process.env.NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN,
|
||||
[DatasetAttr.DATA_PUBLIC_BATCH_CONCURRENCY]: process.env.NEXT_PUBLIC_BATCH_CONCURRENCY,
|
||||
}
|
||||
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -164,6 +164,13 @@ const COOKIE_DOMAIN = getStringConfig(
|
|||
DatasetAttr.DATA_PUBLIC_COOKIE_DOMAIN,
|
||||
'',
|
||||
).trim()
|
||||
|
||||
export const BATCH_CONCURRENCY = getNumberConfig(
|
||||
process.env.NEXT_PUBLIC_BATCH_CONCURRENCY,
|
||||
DatasetAttr.DATA_PUBLIC_BATCH_CONCURRENCY,
|
||||
5, // default
|
||||
)
|
||||
|
||||
export const CSRF_COOKIE_NAME = () => {
|
||||
if (COOKIE_DOMAIN)
|
||||
return 'csrf_token'
|
||||
|
|
|
|||
|
|
@ -131,4 +131,5 @@ export enum DatasetAttr {
|
|||
NEXT_PUBLIC_ZENDESK_FIELD_ID_EMAIL = 'next-public-zendesk-field-id-email',
|
||||
NEXT_PUBLIC_ZENDESK_FIELD_ID_WORKSPACE_ID = 'next-public-zendesk-field-id-workspace-id',
|
||||
NEXT_PUBLIC_ZENDESK_FIELD_ID_PLAN = 'next-public-zendesk-field-id-plan',
|
||||
DATA_PUBLIC_BATCH_CONCURRENCY = 'data-public-batch-concurrency',
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue