Merge branch 'main' into pydantic-remaining

This commit is contained in:
Asuka Minato 2025-12-22 13:18:59 +09:00 committed by GitHub
commit 7273a1bd07
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 937 additions and 143 deletions

View File

@ -1,5 +1,5 @@
---
name: Dify Frontend Testing
name: frontend-testing
description: Generate Jest + React Testing Library tests for Dify frontend components, hooks, and utilities. Triggers on testing, spec files, coverage, Jest, RTL, unit tests, integration tests, or write/review test requests.
---
@ -178,7 +178,7 @@ Process in this order for multi-file testing:
- **500+ lines**: Consider splitting before testing
- **Many dependencies**: Extract logic into hooks first
> 📖 See `guides/workflow.md` for complete workflow details and todo list format.
> 📖 See `references/workflow.md` for complete workflow details and todo list format.
## Testing Strategy
@ -289,17 +289,18 @@ For each test file generated, aim for:
- ✅ **>95%** branch coverage
- ✅ **>95%** line coverage
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `guides/workflow.md`.
> **Note**: For multi-file directories, process one file at a time with full coverage each. See `references/workflow.md`.
## Detailed Guides
For more detailed information, refer to:
- `guides/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `guides/mocking.md` - Mock patterns and best practices
- `guides/async-testing.md` - Async operations and API calls
- `guides/domain-components.md` - Workflow, Dataset, Configuration testing
- `guides/common-patterns.md` - Frequently used testing patterns
- `references/workflow.md` - **Incremental testing workflow** (MUST READ for multi-file testing)
- `references/mocking.md` - Mock patterns and best practices
- `references/async-testing.md` - Async operations and API calls
- `references/domain-components.md` - Workflow, Dataset, Configuration testing
- `references/common-patterns.md` - Frequently used testing patterns
- `references/checklist.md` - Test generation checklist and validation steps
## Authoritative References

1
.codex/skills Symbolic link
View File

@ -0,0 +1 @@
../.claude/skills

10
.github/CODEOWNERS vendored
View File

@ -124,9 +124,15 @@ api/controllers/web/feature.py @GarfieldDai @GareArc
# Backend - Database Migrations
api/migrations/ @snakevash @laipz8200 @MRZHUH
# Backend - Vector DB Middleware
api/configs/middleware/vdb/* @JohnJyong
# Frontend
web/ @iamjoel
# Frontend - Web Tests
.github/workflows/web-tests.yml @iamjoel
# Frontend - App - Orchestration
web/app/components/workflow/ @iamjoel @zxhlyh
web/app/components/workflow-app/ @iamjoel @zxhlyh
@ -198,6 +204,7 @@ web/app/components/plugins/marketplace/ @iamjoel @Yessenia-d
web/app/signin/ @douxc @iamjoel
web/app/signup/ @douxc @iamjoel
web/app/reset-password/ @douxc @iamjoel
web/app/install/ @douxc @iamjoel
web/app/init/ @douxc @iamjoel
web/app/forgot-password/ @douxc @iamjoel
@ -238,3 +245,6 @@ web/app/education-apply/ @iamjoel @zxhlyh
# Frontend - Workspace
web/app/components/header/account-dropdown/workplace-selector/ @iamjoel @zxhlyh
# Docker
docker/* @laipz8200

View File

@ -66,7 +66,7 @@ jobs:
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |
uvx --python 3.13 mdformat . --exclude ".claude/skills/**"
uvx --python 3.13 mdformat . --exclude ".claude/skills/**/SKILL.md"
- name: Install pnpm
uses: pnpm/action-setup@v4

View File

@ -119,14 +119,14 @@ class InstalledAppsListApi(Resource):
recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first()
if recommended_app is None:
raise NotFound("App not found")
raise NotFound("Recommended app not found")
_, current_tenant_id = current_account_with_tenant()
app = db.session.query(App).where(App.id == payload.app_id).first()
if app is None:
raise NotFound("App not found")
raise NotFound("App entity not found")
if not app.is_public:
raise Forbidden("You can't install a non-public app")

View File

@ -18,6 +18,7 @@ from controllers.console.wraps import (
setup_required,
)
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError
from core.mcp.mcp_client import MCPClient
@ -969,7 +970,7 @@ class ToolProviderMCPApi(Resource):
configuration = MCPConfiguration.model_validate(payload.configuration or {})
authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None
# Create provider
# Create provider in transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
result = service.create_provider(
@ -985,7 +986,11 @@ class ToolProviderMCPApi(Resource):
configuration=configuration,
authentication=authentication,
)
return jsonable_encoder(result)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(tenant_id)
return jsonable_encoder(result)
@console_ns.expect(console_ns.models[MCPProviderUpdatePayload.__name__])
@setup_required
@ -997,19 +1002,23 @@ class ToolProviderMCPApi(Resource):
authentication = MCPAuthentication.model_validate(payload.authentication) if payload.authentication else None
_, current_tenant_id = current_account_with_tenant()
# Step 1: Validate server URL change if needed (includes URL format validation and network operation)
validation_result = None
# Step 1: Get provider data for URL validation (short-lived session, no network I/O)
validation_data = None
with Session(db.engine) as session:
service = MCPToolManageService(session=session)
validation_result = service.validate_server_url_change(
tenant_id=current_tenant_id,
provider_id=payload.provider_id,
new_server_url=payload.server_url,
validation_data = service.get_provider_for_url_validation(
tenant_id=current_tenant_id, provider_id=args.provider_id
)
# No need to check for errors here, exceptions will be raised directly
# Step 2: Perform URL validation with network I/O OUTSIDE of any database session
# This prevents holding database locks during potentially slow network operations
validation_result = MCPToolManageService.validate_server_url_standalone(
tenant_id=current_tenant_id,
new_server_url=args["server_url"],
validation_data=validation_data,
)
# Step 2: Perform database update in a transaction
# Step 3: Perform database update in a transaction
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.update_provider(
@ -1026,7 +1035,11 @@ class ToolProviderMCPApi(Resource):
authentication=authentication,
validation_result=validation_result,
)
return {"result": "success"}
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@console_ns.expect(console_ns.models[MCPProviderDeletePayload.__name__])
@setup_required
@ -1038,8 +1051,12 @@ class ToolProviderMCPApi(Resource):
with Session(db.engine) as session, session.begin():
service = MCPToolManageService(session=session)
service.delete_provider(tenant_id=current_tenant_id, provider_id=payload.provider_id)
return {"result": "success"}
service.delete_provider(tenant_id=current_tenant_id, provider_id=args.provider_id)
# Invalidate cache AFTER transaction commits to avoid holding locks during Redis operations
ToolProviderListCache.invalidate_cache(current_tenant_id)
return {"result": "success"}
@console_ns.route("/workspaces/current/tool-provider/mcp/auth")

View File

@ -47,7 +47,11 @@ def build_protected_resource_metadata_discovery_urls(
"""
Build a list of URLs to try for Protected Resource Metadata discovery.
Per SEP-985, supports fallback when discovery fails at one URL.
Per RFC 9728 Section 5.1, supports fallback when discovery fails at one URL.
Priority order:
1. URL from WWW-Authenticate header (if provided)
2. Well-known URI with path: https://example.com/.well-known/oauth-protected-resource/public/mcp
3. Well-known URI at root: https://example.com/.well-known/oauth-protected-resource
"""
urls = []
@ -58,9 +62,18 @@ def build_protected_resource_metadata_discovery_urls(
# Fallback: construct from server URL
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
fallback_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
if fallback_url not in urls:
urls.append(fallback_url)
path = parsed.path.rstrip("/")
# Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp)
if path:
path_url = f"{base_url}/.well-known/oauth-protected-resource{path}"
if path_url not in urls:
urls.append(path_url)
# Priority 3: At root (e.g., /.well-known/oauth-protected-resource)
root_url = f"{base_url}/.well-known/oauth-protected-resource"
if root_url not in urls:
urls.append(root_url)
return urls
@ -71,30 +84,34 @@ def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: st
Supports both OAuth 2.0 (RFC 8414) and OpenID Connect discovery.
Per RFC 8414 section 3:
- If issuer has no path: https://example.com/.well-known/oauth-authorization-server
- If issuer has path: https://example.com/.well-known/oauth-authorization-server{path}
Example:
- issuer: https://example.com/oauth
- metadata: https://example.com/.well-known/oauth-authorization-server/oauth
Per RFC 8414 section 3.1 and section 5, try all possible endpoints:
- OAuth 2.0 with path insertion: https://example.com/.well-known/oauth-authorization-server/tenant1
- OpenID Connect with path insertion: https://example.com/.well-known/openid-configuration/tenant1
- OpenID Connect path appending: https://example.com/tenant1/.well-known/openid-configuration
- OAuth 2.0 at root: https://example.com/.well-known/oauth-authorization-server
- OpenID Connect at root: https://example.com/.well-known/openid-configuration
"""
urls = []
base_url = auth_server_url or server_url
parsed = urlparse(base_url)
base = f"{parsed.scheme}://{parsed.netloc}"
path = parsed.path.rstrip("/") # Remove trailing slash
path = parsed.path.rstrip("/")
# OAuth 2.0 Authorization Server Metadata at root (MCP-03-26)
urls.append(f"{base}/.well-known/oauth-authorization-server")
# Try OpenID Connect discovery first (more common)
urls.append(urljoin(base + "/", ".well-known/openid-configuration"))
# OpenID Connect Discovery at root
urls.append(f"{base}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata (RFC 8414)
# Include the path component if present in the issuer URL
if path:
urls.append(urljoin(base, f".well-known/oauth-authorization-server{path}"))
else:
urls.append(urljoin(base, ".well-known/oauth-authorization-server"))
# OpenID Connect Discovery with path insertion
urls.append(f"{base}/.well-known/openid-configuration{path}")
# OpenID Connect Discovery path appending
urls.append(f"{base}{path}/.well-known/openid-configuration")
# OAuth 2.0 Authorization Server Metadata with path insertion
urls.append(f"{base}/.well-known/oauth-authorization-server{path}")
return urls

View File

@ -59,7 +59,7 @@ class MCPClient:
try:
logger.debug("Not supported method %s found in URL path, trying default 'mcp' method.", method_name)
self.connect_server(sse_client, "sse")
except MCPConnectionError:
except (MCPConnectionError, ValueError):
logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")

View File

@ -83,6 +83,7 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc):
image_count = 0
image_map = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
@ -121,8 +122,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use r_id as key for external images since target_part is undefined
image_map[r_id] = f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
image_map[r_id] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
else:
image_ext = rel.target_ref.split(".")[-1]
if image_ext is None:
@ -150,10 +150,7 @@ class WordExtractor(BaseExtractor):
used_at=naive_utc_now(),
)
db.session.add(upload_file)
# Use target_part as key for internal images
image_map[rel.target_part] = (
f"![image]({dify_config.FILES_URL}/files/{upload_file.id}/file-preview)"
)
image_map[rel.target_part] = f"![image]({base_url}/files/{upload_file.id}/file-preview)"
db.session.commit()
return image_map

View File

@ -86,6 +86,11 @@ class Executor:
node_data.authorization.config.api_key = variable_pool.convert_template(
node_data.authorization.config.api_key
).text
# Validate that API key is not empty after template conversion
if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip():
raise AuthorizationConfigError(
"API key is required for authorization but was empty. Please provide a valid API key."
)
self.url = node_data.url
self.method = node_data.method

View File

@ -15,7 +15,6 @@ from sqlalchemy.orm import Session
from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity
from core.helper import encrypter
from core.helper.provider_cache import NoOpProviderCredentialCache
from core.helper.tool_provider_cache import ToolProviderListCache
from core.mcp.auth.auth_flow import auth
from core.mcp.auth_client import MCPClientWithAuthRetry
from core.mcp.error import MCPAuthError, MCPError
@ -65,6 +64,15 @@ class ServerUrlValidationResult(BaseModel):
return self.needs_validation and self.validation_passed and self.reconnect_result is not None
class ProviderUrlValidationData(BaseModel):
"""Data required for URL validation, extracted from database to perform network operations outside of session"""
current_server_url_hash: str
headers: dict[str, str]
timeout: float | None
sse_read_timeout: float | None
class MCPToolManageService:
"""Service class for managing MCP tools and providers."""
@ -166,9 +174,6 @@ class MCPToolManageService:
self._session.add(mcp_tool)
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
mcp_providers = ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
return mcp_providers
@ -192,7 +197,7 @@ class MCPToolManageService:
Update an MCP provider.
Args:
validation_result: Pre-validation result from validate_server_url_change.
validation_result: Pre-validation result from validate_server_url_standalone.
If provided and contains reconnect_result, it will be used
instead of performing network operations.
"""
@ -251,8 +256,6 @@ class MCPToolManageService:
# Flush changes to database
self._session.flush()
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
except IntegrityError as e:
self._handle_integrity_error(e, name, server_url, server_identifier)
@ -261,9 +264,6 @@ class MCPToolManageService:
mcp_tool = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
self._session.delete(mcp_tool)
# Invalidate tool providers cache
ToolProviderListCache.invalidate_cache(tenant_id)
def list_providers(
self, *, tenant_id: str, for_list: bool = False, include_sensitive: bool = True
) -> list[ToolProviderApiEntity]:
@ -546,30 +546,39 @@ class MCPToolManageService:
)
return self.execute_auth_actions(auth_result)
def _reconnect_provider(self, *, server_url: str, provider: MCPToolProvider) -> ReconnectResult:
"""Attempt to reconnect to MCP provider with new server URL."""
def get_provider_for_url_validation(self, *, tenant_id: str, provider_id: str) -> ProviderUrlValidationData:
"""
Get provider data required for URL validation.
This method performs database read and should be called within a session.
Returns:
ProviderUrlValidationData: Data needed for standalone URL validation
"""
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
provider_entity = provider.to_entity()
headers = provider_entity.headers
return ProviderUrlValidationData(
current_server_url_hash=provider.server_url_hash,
headers=provider_entity.headers,
timeout=provider_entity.timeout,
sse_read_timeout=provider_entity.sse_read_timeout,
)
try:
tools = self._retrieve_remote_mcp_tools(server_url, headers, provider_entity)
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def validate_server_url_change(
self, *, tenant_id: str, provider_id: str, new_server_url: str
@staticmethod
def validate_server_url_standalone(
*,
tenant_id: str,
new_server_url: str,
validation_data: ProviderUrlValidationData,
) -> ServerUrlValidationResult:
"""
Validate server URL change by attempting to connect to the new server.
This method should be called BEFORE update_provider to perform network operations
outside of the database transaction.
This method performs network operations and MUST be called OUTSIDE of any database session
to avoid holding locks during network I/O.
Args:
tenant_id: Tenant ID for encryption
new_server_url: The new server URL to validate
validation_data: Provider data obtained from get_provider_for_url_validation
Returns:
ServerUrlValidationResult: Validation result with connection status and tools if successful
@ -579,25 +588,30 @@ class MCPToolManageService:
return ServerUrlValidationResult(needs_validation=False)
# Validate URL format
if not self._is_valid_url(new_server_url):
parsed = urlparse(new_server_url)
if not all([parsed.scheme, parsed.netloc]) or parsed.scheme not in ["http", "https"]:
raise ValueError("Server URL is not valid.")
# Always encrypt and hash the URL
encrypted_server_url = encrypter.encrypt_token(tenant_id, new_server_url)
new_server_url_hash = hashlib.sha256(new_server_url.encode()).hexdigest()
# Get current provider
provider = self.get_provider(provider_id=provider_id, tenant_id=tenant_id)
# Check if URL is actually different
if new_server_url_hash == provider.server_url_hash:
if new_server_url_hash == validation_data.current_server_url_hash:
# URL hasn't changed, but still return the encrypted data
return ServerUrlValidationResult(
needs_validation=False, encrypted_server_url=encrypted_server_url, server_url_hash=new_server_url_hash
needs_validation=False,
encrypted_server_url=encrypted_server_url,
server_url_hash=new_server_url_hash,
)
# Perform validation by attempting to connect
reconnect_result = self._reconnect_provider(server_url=new_server_url, provider=provider)
# Perform network validation - this is the expensive operation that should be outside session
reconnect_result = MCPToolManageService._reconnect_with_url(
server_url=new_server_url,
headers=validation_data.headers,
timeout=validation_data.timeout,
sse_read_timeout=validation_data.sse_read_timeout,
)
return ServerUrlValidationResult(
needs_validation=True,
validation_passed=True,
@ -606,6 +620,38 @@ class MCPToolManageService:
server_url_hash=new_server_url_hash,
)
@staticmethod
def _reconnect_with_url(
*,
server_url: str,
headers: dict[str, str],
timeout: float | None,
sse_read_timeout: float | None,
) -> ReconnectResult:
"""
Attempt to connect to MCP server with given URL.
This is a static method that performs network I/O without database access.
"""
from core.mcp.mcp_client import MCPClient
try:
with MCPClient(
server_url=server_url,
headers=headers,
timeout=timeout,
sse_read_timeout=sse_read_timeout,
) as mcp_client:
tools = mcp_client.list_tools()
return ReconnectResult(
authed=True,
tools=json.dumps([tool.model_dump() for tool in tools]),
encrypted_credentials=EMPTY_CREDENTIALS_JSON,
)
except MCPAuthError:
return ReconnectResult(authed=False, tools=EMPTY_TOOLS_JSON, encrypted_credentials=EMPTY_CREDENTIALS_JSON)
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
def _build_tool_provider_response(
self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list
) -> ToolProviderApiEntity:

View File

@ -2,7 +2,6 @@ import logging
import time
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import select
@ -12,7 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from models.source import DataSourceOauthBinding
from services.datasource_provider_service import DatasourceProviderService
logger = logging.getLogger(__name__)
@ -48,27 +47,36 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
db.session.commit()
db.session.close()
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=data_source_binding.access_token,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)

View File

@ -6,6 +6,7 @@ import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.node_factory import DifyNodeFactory
@ -169,13 +170,14 @@ def test_custom_authorization_header(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, no header should be set."""
def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock):
"""Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised."""
from core.workflow.nodes.http_request.entities import (
HttpRequestNodeAuthorization,
HttpRequestNodeData,
HttpRequestNodeTimeout,
)
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -208,16 +210,13 @@ def test_custom_auth_with_empty_api_key_does_not_set_header(setup_http_mock):
ssl_verify=True,
)
# Create executor
executor = Executor(
node_data=node_data, timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), variable_pool=variable_pool
)
# Get assembled headers
headers = executor._assembling_headers()
# When api_key is empty, the custom header should NOT be set
assert "X-Custom-Auth" not in headers
# Create executor should raise AuthorizationConfigError
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10),
variable_pool=variable_pool,
)
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
@ -305,9 +304,10 @@ def test_basic_authorization_with_custom_header_ignored(setup_http_mock):
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_custom_authorization_with_empty_api_key(setup_http_mock):
"""
Test that custom authorization doesn't set header when api_key is empty.
This test verifies the fix for issue #23554.
Test that custom authorization raises error when api_key is empty.
This test verifies the fix for issue #21830.
"""
node = init_http_node(
config={
"id": "1",
@ -333,11 +333,10 @@ def test_custom_authorization_with_empty_api_key(setup_http_mock):
)
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")
# Custom header should NOT be set when api_key is empty
assert "X-Custom-Auth:" not in data
# Should fail with AuthorizationConfigError
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "API key is required" in result.error
assert result.error_type == "AuthorizationConfigError"
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)

View File

@ -1308,18 +1308,17 @@ class TestMCPToolManageService:
type("MockTool", (), {"model_dump": lambda self: {"name": "test_tool_2", "description": "Test tool 2"}})(),
]
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
# Setup mock client
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.return_value = mock_tools
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
result = MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
# Assert: Verify the expected outcomes
@ -1337,8 +1336,12 @@ class TestMCPToolManageService:
assert tools_data[1]["name"] == "test_tool_2"
# Verify mock interactions
provider_entity = mcp_provider.to_entity()
mock_mcp_client.assert_called_once()
mock_mcp_client.assert_called_once_with(
server_url="https://example.com/mcp",
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies):
"""
@ -1361,19 +1364,18 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise authentication error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPAuthError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required")
# Act: Execute the method under test
from extensions.ext_database import db
service = MCPToolManageService(db.session())
result = service._reconnect_provider(
result = MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)
# Assert: Verify the expected outcomes
@ -1404,18 +1406,17 @@ class TestMCPToolManageService:
)
# Mock MCPClient to raise connection error
with patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") as mock_mcp_client:
with patch("core.mcp.mcp_client.MCPClient") as mock_mcp_client:
from core.mcp.error import MCPError
mock_client_instance = mock_mcp_client.return_value.__enter__.return_value
mock_client_instance.list_tools.side_effect = MCPError("Connection failed")
# Act & Assert: Verify proper error handling
from extensions.ext_database import db
service = MCPToolManageService(db.session())
with pytest.raises(ValueError, match="Failed to re-connect MCP server: Connection failed"):
service._reconnect_provider(
MCPToolManageService._reconnect_with_url(
server_url="https://example.com/mcp",
provider=mcp_provider,
headers={"X-Test": "1"},
timeout=mcp_provider.timeout,
sse_read_timeout=mcp_provider.sse_read_timeout,
)

View File

@ -132,3 +132,36 @@ def test_extract_images_from_docx(monkeypatch):
# DB interactions should be recorded
assert len(db_stub.session.added) == 2
assert db_stub.session.committed is True
def test_extract_images_from_docx_uses_internal_files_url():
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
# Test the URL generation logic directly
from configs import dify_config
# Mock the configuration values
original_files_url = getattr(dify_config, "FILES_URL", None)
original_internal_files_url = getattr(dify_config, "INTERNAL_FILES_URL", None)
try:
# Set both URLs - INTERNAL should take precedence
dify_config.FILES_URL = "http://external.example.com"
dify_config.INTERNAL_FILES_URL = "http://internal.docker:5001"
# Test the URL generation logic (same as in word_extractor.py)
upload_file_id = "test_file_id"
# This is the pattern we fixed in the word extractor
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
generated_url = f"{base_url}/files/{upload_file_id}/file-preview"
# Verify that INTERNAL_FILES_URL is used instead of FILES_URL
assert "http://internal.docker:5001" in generated_url, f"Expected internal URL, got: {generated_url}"
assert "http://external.example.com" not in generated_url, f"Should not use external URL, got: {generated_url}"
finally:
# Restore original values
if original_files_url is not None:
dify_config.FILES_URL = original_files_url
if original_internal_files_url is not None:
dify_config.INTERNAL_FILES_URL = original_internal_files_url

View File

@ -1,3 +1,5 @@
import pytest
from core.workflow.nodes.http_request import (
BodyData,
HttpRequestNodeAuthorization,
@ -5,6 +7,7 @@ from core.workflow.nodes.http_request import (
HttpRequestNodeData,
)
from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout
from core.workflow.nodes.http_request.exc import AuthorizationConfigError
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@ -348,3 +351,127 @@ def test_init_params():
executor = create_executor("key1:value1\n\nkey2:value2\n\n")
executor._init_params()
assert executor.params == [("key1", "value1"), ("key2", "value2")]
def test_empty_api_key_raises_error_bearer():
"""Test that empty API key raises AuthorizationConfigError for bearer auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "bearer", "api_key": ""},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_empty_api_key_raises_error_basic():
"""Test that empty API key raises AuthorizationConfigError for basic auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "basic", "api_key": ""},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_empty_api_key_raises_error_custom():
"""Test that empty API key raises AuthorizationConfigError for custom auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "custom", "api_key": "", "header": "X-Custom-Auth"},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_whitespace_only_api_key_raises_error():
"""Test that whitespace-only API key raises AuthorizationConfigError."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "bearer", "api_key": " "},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
with pytest.raises(AuthorizationConfigError, match="API key is required"):
Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
def test_valid_api_key_works():
"""Test that valid API key works correctly for bearer auth."""
variable_pool = VariablePool(system_variables=SystemVariable.empty())
node_data = HttpRequestNodeData(
title="test",
method="get",
url="http://example.com",
headers="",
params="",
authorization=HttpRequestNodeAuthorization(
type="api-key",
config={"type": "bearer", "api_key": "valid-api-key-123"},
),
)
timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30)
executor = Executor(
node_data=node_data,
timeout=timeout,
variable_pool=variable_pool,
)
# Should not raise an error
headers = executor._assembling_headers()
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer valid-api-key-123"

View File

@ -0,0 +1,520 @@
"""
Unit tests for document indexing sync task.
This module tests the document indexing sync task functionality including:
- Syncing Notion documents when updated
- Validating document and data source existence
- Credential validation and retrieval
- Cleaning old segments before re-indexing
- Error handling and edge cases
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
def tenant_id():
"""Generate a unique tenant ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def dataset_id():
"""Generate a unique dataset ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def document_id():
"""Generate a unique document ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def notion_workspace_id():
"""Generate a Notion workspace ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def notion_page_id():
"""Generate a Notion page ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def credential_id():
"""Generate a credential ID for testing."""
return str(uuid.uuid4())
@pytest.fixture
def mock_dataset(dataset_id, tenant_id):
"""Create a mock Dataset object."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.tenant_id = tenant_id
dataset.indexing_technique = "high_quality"
dataset.embedding_model_provider = "openai"
dataset.embedding_model = "text-embedding-ada-002"
return dataset
@pytest.fixture
def mock_document(document_id, dataset_id, tenant_id, notion_workspace_id, notion_page_id, credential_id):
"""Create a mock Document object with Notion data source."""
doc = Mock(spec=Document)
doc.id = document_id
doc.dataset_id = dataset_id
doc.tenant_id = tenant_id
doc.data_source_type = "notion_import"
doc.indexing_status = "completed"
doc.error = None
doc.stopped_at = None
doc.processing_started_at = None
doc.doc_form = "text_model"
doc.data_source_info_dict = {
"notion_workspace_id": notion_workspace_id,
"notion_page_id": notion_page_id,
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
"credential_id": credential_id,
}
return doc
@pytest.fixture
def mock_document_segments(document_id):
"""Create mock DocumentSegment objects."""
segments = []
for i in range(3):
segment = Mock(spec=DocumentSegment)
segment.id = str(uuid.uuid4())
segment.document_id = document_id
segment.index_node_id = f"node-{document_id}-{i}"
segments.append(segment)
return segments
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
@pytest.fixture
def mock_datasource_provider_service():
"""Mock DatasourceProviderService."""
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
mock_service = MagicMock()
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
mock_service_class.return_value = mock_service
yield mock_service
@pytest.fixture
def mock_notion_extractor():
"""Mock NotionExtractor."""
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Updated time
mock_extractor_class.return_value = mock_extractor
yield mock_extractor
@pytest.fixture
def mock_index_processor_factory():
"""Mock IndexProcessorFactory."""
with patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_factory:
mock_processor = MagicMock()
mock_processor.clean = Mock()
mock_factory.return_value.init_index_processor.return_value = mock_processor
yield mock_factory
@pytest.fixture
def mock_indexing_runner():
"""Mock IndexingRunner."""
with patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class:
mock_runner = MagicMock(spec=IndexingRunner)
mock_runner.run = Mock()
mock_runner_class.return_value = mock_runner
yield mock_runner
# ============================================================================
# Tests for document_indexing_sync_task
# ============================================================================
class TestDocumentIndexingSyncTask:
"""Tests for the document_indexing_sync_task function."""
def test_document_not_found(self, mock_db_session, dataset_id, document_id):
"""Test that task handles document not found gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = None
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_db_session.close.assert_called_once()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing."""
# Arrange
mock_document.data_source_info_dict = {"notion_page_id": "page123", "type": "page"}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_missing_notion_page_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_page_id is missing."""
# Arrange
mock_document.data_source_info_dict = {"notion_workspace_id": "ws123", "type": "page"}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_empty_data_source_info(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when data_source_info is empty."""
# Arrange
mock_document.data_source_info_dict = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Act & Assert
with pytest.raises(ValueError, match="no notion page found"):
document_indexing_sync_task(dataset_id, document_id)
def test_credential_not_found(
self,
mock_db_session,
mock_datasource_provider_service,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles missing credentials by updating document status."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None
mock_db_session.commit.assert_called()
mock_db_session.close.assert_called()
def test_page_not_updated(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task does nothing when page has not been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
# Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# No session operations should be performed beyond the initial query
mock_db_session.close.assert_not_called()
def test_successful_sync_when_page_updated(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test successful sync flow when Notion page has been updated."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Verify document status was updated to parsing
assert mock_document.indexing_status == "parsing"
assert mock_document.processing_started_at is not None
# Verify segments were cleaned
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations
assert mock_db_session.commit.called
mock_db_session.close.assert_called_once()
def test_dataset_not_found_during_cleaning(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles dataset not found during cleaning phase."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Document should still be set to parsing
assert mock_document.indexing_status == "parsing"
# Session should be closed after error
mock_db_session.close.assert_called_once()
def test_cleaning_error_continues_to_indexing(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
dataset_id,
document_id,
):
"""Test that indexing continues even if cleaning fails."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Indexing should still be attempted despite cleaning error
mock_indexing_runner.run.assert_called_once_with([mock_document])
mock_db_session.close.assert_called_once()
def test_indexing_runner_document_paused_error(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that DocumentIsPausedError is handled gracefully."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Session should be closed after handling error
mock_db_session.close.assert_called_once()
def test_indexing_runner_general_error(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that general exceptions during indexing are handled."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
# Session should be closed after error
mock_db_session.close.assert_called_once()
def test_notion_extractor_initialized_with_correct_params(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
notion_workspace_id,
notion_page_id,
):
"""Test that NotionExtractor is initialized with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" # No update
# Act
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
mock_extractor_class.return_value = mock_extractor
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_extractor_class.assert_called_once_with(
notion_workspace_id=notion_workspace_id,
notion_obj_id=notion_page_id,
notion_page_type="page",
notion_access_token="test_token",
tenant_id=mock_document.tenant_id,
)
def test_datasource_credentials_requested_correctly(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
credential_id,
):
"""Test that datasource credentials are requested with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
tenant_id=mock_document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
def test_credential_id_missing_uses_none(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_document,
dataset_id,
document_id,
):
"""Test that task handles missing credential_id by passing None."""
# Arrange
mock_document.data_source_info_dict = {
"notion_workspace_id": "ws123",
"notion_page_id": "page123",
"type": "page",
"last_edited_time": "2024-01-01T00:00:00Z",
}
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_datasource_provider_service.get_datasource_credentials.assert_called_once_with(
tenant_id=mock_document.tenant_id,
credential_id=None,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
def test_index_processor_clean_called_with_correct_params(
self,
mock_db_session,
mock_datasource_provider_service,
mock_notion_extractor,
mock_index_processor_factory,
mock_indexing_runner,
mock_dataset,
mock_document,
mock_document_segments,
dataset_id,
document_id,
):
"""Test that index processor clean is called with correct parameters."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
expected_node_ids = [seg.index_node_id for seg in mock_document_segments]
mock_processor.clean.assert_called_once_with(
mock_dataset, expected_node_ids, with_keywords=True, delete_child_chunks=True
)

View File

@ -1,6 +1,7 @@
import {
createContext,
useContext,
useEffect,
useRef,
} from 'react'
import {
@ -10,6 +11,7 @@ import {
import type {
FileEntity,
} from './types'
import { isEqual } from 'lodash-es'
type Shape = {
files: FileEntity[]
@ -55,10 +57,20 @@ export const FileContextProvider = ({
onChange,
}: FileProviderProps) => {
const storeRef = useRef<FileStore | undefined>(undefined)
if (!storeRef.current)
storeRef.current = createFileStore(value, onChange)
useEffect(() => {
if (!storeRef.current)
return
if (isEqual(value, storeRef.current.getState().files))
return
storeRef.current.setState({
files: value ? [...value] : [],
})
}, [value])
return (
<FileContext.Provider value={storeRef.current}>
{children}