mirror of
https://github.com/langgenius/dify.git
synced 2026-03-10 11:10:19 +08:00
test(api): add autospec to MagicMock-based patch usage (#32752)
This commit is contained in:
parent
c034eb036c
commit
20fcc95db9
@ -360,7 +360,7 @@ class TestEndToEndCacheFlow:
|
||||
class TestRedisFailover:
|
||||
"""Test behavior when Redis is unavailable."""
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
@patch("services.api_token_service.redis_client", autospec=True)
|
||||
def test_graceful_degradation_when_redis_fails(self, mock_redis):
|
||||
"""Test system degrades gracefully when Redis is unavailable."""
|
||||
from redis import RedisError
|
||||
|
||||
@ -41,17 +41,15 @@ class TestOpenSearchConfig:
|
||||
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
|
||||
assert params["http_auth"] == ("admin", "password")
|
||||
|
||||
@patch("boto3.Session")
|
||||
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
|
||||
@patch("boto3.Session", autospec=True)
|
||||
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth", autospec=True)
|
||||
def test_to_opensearch_params_with_aws_managed_iam(
|
||||
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
|
||||
):
|
||||
mock_credentials = MagicMock()
|
||||
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
|
||||
|
||||
mock_auth_instance = MagicMock()
|
||||
mock_aws_signer_auth.return_value = mock_auth_instance
|
||||
|
||||
mock_auth_instance = mock_aws_signer_auth.return_value
|
||||
aws_region = "ap-southeast-2"
|
||||
aws_service = "aoss"
|
||||
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
|
||||
@ -157,7 +155,7 @@ class TestOpenSearchVector:
|
||||
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
|
||||
embedding = [0.1] * 128
|
||||
|
||||
with patch("opensearchpy.helpers.bulk") as mock_bulk:
|
||||
with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk:
|
||||
mock_bulk.return_value = ([], [])
|
||||
self.vector.add_texts([doc], [embedding])
|
||||
|
||||
@ -171,7 +169,7 @@ class TestOpenSearchVector:
|
||||
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
|
||||
embedding = [0.1] * 128
|
||||
|
||||
with patch("opensearchpy.helpers.bulk") as mock_bulk:
|
||||
with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk:
|
||||
mock_bulk.return_value = ([], [])
|
||||
self.vector.add_texts([doc], [embedding])
|
||||
|
||||
|
||||
@ -19,14 +19,14 @@ class TestAgentService:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client,
|
||||
patch("services.agent_service.ToolManager") as mock_tool_manager,
|
||||
patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager,
|
||||
patch("services.agent_service.PluginAgentClient", autospec=True) as mock_plugin_agent_client,
|
||||
patch("services.agent_service.ToolManager", autospec=True) as mock_tool_manager,
|
||||
patch("services.agent_service.AgentConfigManager", autospec=True) as mock_agent_config_manager,
|
||||
patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
|
||||
patch("services.app_service.FeatureService") as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.app_service.FeatureService", autospec=True) as mock_feature_service,
|
||||
patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service,
|
||||
patch("services.app_service.ModelManager", autospec=True) as mock_model_manager,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
|
||||
):
|
||||
# Setup default mock returns for agent service
|
||||
mock_plugin_agent_client_instance = mock_plugin_agent_client.return_value
|
||||
|
||||
@ -18,18 +18,22 @@ class TestAppGenerateService:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.billing_service.BillingService") as mock_billing_service,
|
||||
patch("services.app_generate_service.WorkflowService") as mock_workflow_service,
|
||||
patch("services.app_generate_service.RateLimit") as mock_rate_limit,
|
||||
patch("services.app_generate_service.CompletionAppGenerator") as mock_completion_generator,
|
||||
patch("services.app_generate_service.ChatAppGenerator") as mock_chat_generator,
|
||||
patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator,
|
||||
patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator,
|
||||
patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator,
|
||||
patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService") as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config") as mock_dify_config,
|
||||
patch("configs.dify_config") as mock_global_dify_config,
|
||||
patch("services.billing_service.BillingService", autospec=True) as mock_billing_service,
|
||||
patch("services.app_generate_service.WorkflowService", autospec=True) as mock_workflow_service,
|
||||
patch("services.app_generate_service.RateLimit", autospec=True) as mock_rate_limit,
|
||||
patch("services.app_generate_service.CompletionAppGenerator", autospec=True) as mock_completion_generator,
|
||||
patch("services.app_generate_service.ChatAppGenerator", autospec=True) as mock_chat_generator,
|
||||
patch("services.app_generate_service.AgentChatAppGenerator", autospec=True) as mock_agent_chat_generator,
|
||||
patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator", autospec=True
|
||||
) as mock_advanced_chat_generator,
|
||||
patch("services.app_generate_service.WorkflowAppGenerator", autospec=True) as mock_workflow_generator,
|
||||
patch(
|
||||
"services.app_generate_service.MessageBasedAppGenerator", autospec=True
|
||||
) as mock_message_based_generator,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
|
||||
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
|
||||
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
|
||||
):
|
||||
# Setup default mock returns for billing service
|
||||
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
|
||||
@ -983,7 +987,7 @@ class TestAppGenerateService:
|
||||
}
|
||||
|
||||
# Execute the method under test
|
||||
with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params:
|
||||
with patch("services.app_generate_service.AppExecutionParams", autospec=True) as mock_exec_params:
|
||||
mock_payload = MagicMock()
|
||||
mock_payload.workflow_run_id = fake.uuid4()
|
||||
mock_payload.model_dump_json.return_value = "{}"
|
||||
|
||||
@ -17,10 +17,12 @@ class TestModelLoadBalancingService:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
|
||||
patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
|
||||
patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
|
||||
patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
|
||||
patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager,
|
||||
patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager,
|
||||
patch(
|
||||
"services.model_load_balancing_service.ModelProviderFactory", autospec=True
|
||||
) as mock_model_provider_factory,
|
||||
patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_provider_manager_instance = mock_provider_manager.return_value
|
||||
|
||||
@ -17,8 +17,8 @@ class TestModelProviderService:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("services.model_provider_service.ProviderManager") as mock_provider_manager,
|
||||
patch("services.model_provider_service.ModelProviderFactory") as mock_model_provider_factory,
|
||||
patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager,
|
||||
patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_provider_manager.return_value.get_configurations.return_value = MagicMock()
|
||||
@ -526,7 +526,9 @@ class TestModelProviderService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method:
|
||||
with patch.object(
|
||||
service, "get_provider_credential", return_value=expected_credentials, autospec=True
|
||||
) as mock_method:
|
||||
result = service.get_provider_credential(tenant.id, "openai")
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
@ -854,7 +856,9 @@ class TestModelProviderService:
|
||||
|
||||
# Act: Execute the method under test
|
||||
service = ModelProviderService()
|
||||
with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method:
|
||||
with patch.object(
|
||||
service, "get_model_credential", return_value=expected_credentials, autospec=True
|
||||
) as mock_method:
|
||||
result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
||||
@ -22,16 +22,13 @@ class TestWebhookService:
|
||||
def mock_external_dependencies(self):
|
||||
"""Mock external service dependencies."""
|
||||
with (
|
||||
patch("services.trigger.webhook_service.AsyncWorkflowService") as mock_async_service,
|
||||
patch("services.trigger.webhook_service.ToolFileManager") as mock_tool_file_manager,
|
||||
patch("services.trigger.webhook_service.file_factory") as mock_file_factory,
|
||||
patch("services.account_service.FeatureService") as mock_feature_service,
|
||||
patch("services.trigger.webhook_service.AsyncWorkflowService", autospec=True) as mock_async_service,
|
||||
patch("services.trigger.webhook_service.ToolFileManager", autospec=True) as mock_tool_file_manager,
|
||||
patch("services.trigger.webhook_service.file_factory", autospec=True) as mock_file_factory,
|
||||
patch("services.account_service.FeatureService", autospec=True) as mock_feature_service,
|
||||
):
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
@ -435,12 +432,12 @@ class TestWebhookService:
|
||||
|
||||
with flask_app_with_containers.app_context():
|
||||
# Mock tenant owner lookup to return the test account
|
||||
with patch("services.trigger.webhook_service.select") as mock_select:
|
||||
with patch("services.trigger.webhook_service.select", autospec=True) as mock_select:
|
||||
mock_query = MagicMock()
|
||||
mock_select.return_value.join.return_value.where.return_value = mock_query
|
||||
|
||||
# Mock the session to return our test account
|
||||
with patch("services.trigger.webhook_service.Session") as mock_session:
|
||||
with patch("services.trigger.webhook_service.Session", autospec=True) as mock_session:
|
||||
mock_session_instance = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = mock_session_instance
|
||||
mock_session_instance.scalar.return_value = test_data["account"]
|
||||
@ -462,7 +459,7 @@ class TestWebhookService:
|
||||
with flask_app_with_containers.app_context():
|
||||
# Mock EndUserService to raise an exception
|
||||
with patch(
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type"
|
||||
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", autospec=True
|
||||
) as mock_end_user:
|
||||
mock_end_user.side_effect = ValueError("Failed to create end user")
|
||||
|
||||
|
||||
@ -764,7 +764,7 @@ class TestWorkflowService:
|
||||
# Act - Mock current_user context and pass session
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("flask_login.utils._get_user", return_value=account):
|
||||
with patch("flask_login.utils._get_user", return_value=account, autospec=True):
|
||||
result = workflow_service.publish_workflow(
|
||||
session=db_session_with_containers, app_model=app, account=account
|
||||
)
|
||||
@ -1401,6 +1401,7 @@ class TestWorkflowService:
|
||||
DifyNodeFactory,
|
||||
"_build_model_instance_for_llm_node",
|
||||
return_value=MagicMock(spec=ModelInstance),
|
||||
autospec=True,
|
||||
):
|
||||
result = workflow_service.run_free_workflow_node(
|
||||
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
|
||||
|
||||
@ -18,7 +18,9 @@ class TestAddDocumentToIndexTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.add_document_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
|
||||
patch(
|
||||
"tasks.add_document_to_index_task.IndexProcessorFactory", autospec=True
|
||||
) as mock_index_processor_factory,
|
||||
):
|
||||
# Setup mock index processor
|
||||
mock_processor = MagicMock()
|
||||
@ -378,7 +380,7 @@ class TestAddDocumentToIndexTask:
|
||||
redis_client.set(indexing_cache_key, "processing", ex=300)
|
||||
|
||||
# Mock the get_child_chunks method for each segment
|
||||
with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
|
||||
with patch.object(DocumentSegment, "get_child_chunks", autospec=True) as mock_get_child_chunks:
|
||||
# Setup mock to return child chunks for each segment
|
||||
mock_child_chunks = []
|
||||
for i in range(2): # Each segment has 2 child chunks
|
||||
|
||||
@ -51,9 +51,9 @@ class TestBatchCreateSegmentToIndexTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.batch_create_segment_to_index_task.storage") as mock_storage,
|
||||
patch("tasks.batch_create_segment_to_index_task.ModelManager") as mock_model_manager,
|
||||
patch("tasks.batch_create_segment_to_index_task.VectorService") as mock_vector_service,
|
||||
patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage,
|
||||
patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager,
|
||||
patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_storage.download.return_value = None
|
||||
|
||||
@ -63,8 +63,8 @@ class TestCleanDatasetTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.clean_dataset_task.storage") as mock_storage,
|
||||
patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_index_processor_factory,
|
||||
patch("tasks.clean_dataset_task.storage", autospec=True) as mock_storage,
|
||||
patch("tasks.clean_dataset_task.IndexProcessorFactory", autospec=True) as mock_index_processor_factory,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_storage.delete.return_value = None
|
||||
@ -597,7 +597,7 @@ class TestCleanDatasetTask:
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# Mock the get_image_upload_file_ids function to return our image file IDs
|
||||
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
|
||||
with patch("tasks.clean_dataset_task.get_image_upload_file_ids", autospec=True) as mock_get_image_ids:
|
||||
mock_get_image_ids.return_value = [f.id for f in image_files]
|
||||
|
||||
# Execute the task
|
||||
|
||||
@ -41,7 +41,7 @@ class TestCreateSegmentToIndexTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory,
|
||||
patch("tasks.create_segment_to_index_task.IndexProcessorFactory", autospec=True) as mock_factory,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_processor = MagicMock()
|
||||
@ -708,7 +708,7 @@ class TestCreateSegmentToIndexTask:
|
||||
redis_client.set(cache_key, "processing", ex=300)
|
||||
|
||||
# Mock Redis to raise exception in finally block
|
||||
with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")):
|
||||
with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed"), autospec=True):
|
||||
# Act: Execute the task - Redis failure should not prevent completion
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
create_segment_to_index_task(segment.id)
|
||||
|
||||
@ -37,7 +37,7 @@ class _TrackedSessionContext:
|
||||
self._closed_sessions.append(self._session)
|
||||
return original_close(*args, **kwargs)
|
||||
|
||||
self._close_patcher = patch.object(self._session, "close", side_effect=_tracked_close)
|
||||
self._close_patcher = patch.object(self._session, "close", side_effect=_tracked_close, autospec=True)
|
||||
self._close_patcher.start()
|
||||
return self._session
|
||||
|
||||
@ -69,7 +69,9 @@ def session_close_tracker():
|
||||
original_context_manager = original_create_session(*args, **kwargs)
|
||||
return _TrackedSessionContext(original_context_manager, opened_sessions, closed_sessions)
|
||||
|
||||
with patch.object(task_module.session_factory, "create_session", side_effect=_tracked_create_session):
|
||||
with patch.object(
|
||||
task_module.session_factory, "create_session", side_effect=_tracked_create_session, autospec=True
|
||||
):
|
||||
yield {"opened_sessions": opened_sessions, "closed_sessions": closed_sessions}
|
||||
|
||||
|
||||
@ -77,13 +79,11 @@ def session_close_tracker():
|
||||
def patched_external_dependencies():
|
||||
"""Patch non-DB collaborators while keeping database behavior real."""
|
||||
with (
|
||||
patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
|
||||
patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
|
||||
patch("tasks.document_indexing_task.generate_summary_index_task") as mock_summary_task,
|
||||
patch("tasks.document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner,
|
||||
patch("tasks.document_indexing_task.FeatureService", autospec=True) as mock_feature_service,
|
||||
patch("tasks.document_indexing_task.generate_summary_index_task", autospec=True) as mock_summary_task,
|
||||
):
|
||||
mock_runner_instance = MagicMock()
|
||||
mock_indexing_runner.return_value = mock_runner_instance
|
||||
|
||||
mock_runner_instance = mock_indexing_runner.return_value
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||
@ -307,9 +307,17 @@ class TestDatasetIndexingTaskIntegration:
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[next_task]),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time") as set_waiting_spy,
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
|
||||
return_value=[next_task],
|
||||
autospec=True,
|
||||
),
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True
|
||||
) as set_waiting_spy,
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
|
||||
) as delete_key_spy,
|
||||
):
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
@ -336,8 +344,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True),
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
|
||||
) as delete_key_spy,
|
||||
):
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
@ -426,9 +436,13 @@ class TestDatasetIndexingTaskIntegration:
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("tasks.document_indexing_task._document_indexing", side_effect=Exception("failed")),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[next_task]),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time"),
|
||||
patch("tasks.document_indexing_task._document_indexing", side_effect=Exception("failed"), autospec=True),
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
|
||||
return_value=[next_task],
|
||||
autospec=True,
|
||||
),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True),
|
||||
):
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
@ -511,8 +525,11 @@ class TestDatasetIndexingTaskIntegration:
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
|
||||
return_value=pending_tasks[:concurrency_limit],
|
||||
autospec=True,
|
||||
),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time") as set_waiting_spy,
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True
|
||||
) as set_waiting_spy,
|
||||
):
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
@ -538,8 +555,12 @@ class TestDatasetIndexingTaskIntegration:
|
||||
# Act
|
||||
with (
|
||||
patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=ordered_tasks),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time"),
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
|
||||
return_value=ordered_tasks,
|
||||
autospec=True,
|
||||
),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True),
|
||||
):
|
||||
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
|
||||
|
||||
@ -578,8 +599,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True),
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
|
||||
) as delete_key_spy,
|
||||
):
|
||||
normal_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
|
||||
|
||||
@ -599,8 +622,10 @@ class TestDatasetIndexingTaskIntegration:
|
||||
|
||||
# Act
|
||||
with (
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]),
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
|
||||
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True),
|
||||
patch(
|
||||
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
|
||||
) as delete_key_spy,
|
||||
):
|
||||
priority_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
|
||||
|
||||
|
||||
@ -216,7 +216,7 @@ class TestDeleteSegmentFromIndexTask:
|
||||
db_session_with_containers.commit()
|
||||
return segments
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
|
||||
def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
|
||||
"""
|
||||
Test successful segment deletion from index with comprehensive verification.
|
||||
@ -399,7 +399,7 @@ class TestDeleteSegmentFromIndexTask:
|
||||
# Verify the task completed without exceptions
|
||||
assert result is None # Task should return None when indexing is not completed
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
|
||||
def test_delete_segment_from_index_task_index_processor_clean(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
@ -457,7 +457,7 @@ class TestDeleteSegmentFromIndexTask:
|
||||
mock_index_processor_factory.reset_mock()
|
||||
mock_processor.reset_mock()
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
|
||||
def test_delete_segment_from_index_task_exception_handling(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
@ -501,7 +501,7 @@ class TestDeleteSegmentFromIndexTask:
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
|
||||
def test_delete_segment_from_index_task_empty_index_node_ids(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
@ -543,7 +543,7 @@ class TestDeleteSegmentFromIndexTask:
|
||||
assert call_args[1]["with_keywords"] is True
|
||||
assert call_args[1]["delete_child_chunks"] is True
|
||||
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
|
||||
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
|
||||
def test_delete_segment_from_index_task_large_index_node_ids(
|
||||
self, mock_index_processor_factory, db_session_with_containers
|
||||
):
|
||||
|
||||
@ -32,14 +32,11 @@ class TestDocumentIndexingTasks:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
|
||||
patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
|
||||
patch("tasks.document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner,
|
||||
patch("tasks.document_indexing_task.FeatureService", autospec=True) as mock_feature_service,
|
||||
):
|
||||
# Setup mock indexing runner
|
||||
mock_runner_instance = MagicMock()
|
||||
mock_indexing_runner.return_value = mock_runner_instance
|
||||
|
||||
# Setup mock feature service
|
||||
mock_runner_instance = mock_indexing_runner.return_value # Setup mock feature service
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
@ -16,15 +16,13 @@ class TestDocumentIndexingUpdateTask:
|
||||
- IndexingRunner.run([...])
|
||||
"""
|
||||
with (
|
||||
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
|
||||
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
|
||||
patch("tasks.document_indexing_update_task.IndexProcessorFactory", autospec=True) as mock_factory,
|
||||
patch("tasks.document_indexing_update_task.IndexingRunner", autospec=True) as mock_runner,
|
||||
):
|
||||
processor_instance = MagicMock()
|
||||
mock_factory.return_value.init_index_processor.return_value = processor_instance
|
||||
|
||||
runner_instance = MagicMock()
|
||||
mock_runner.return_value = runner_instance
|
||||
|
||||
runner_instance = mock_runner.return_value
|
||||
yield {
|
||||
"factory": mock_factory,
|
||||
"processor": processor_instance,
|
||||
|
||||
@ -31,15 +31,14 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner,
|
||||
patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service,
|
||||
patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory,
|
||||
patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner,
|
||||
patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_feature_service,
|
||||
patch(
|
||||
"tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True
|
||||
) as mock_index_processor_factory,
|
||||
):
|
||||
# Setup mock indexing runner
|
||||
mock_runner_instance = MagicMock()
|
||||
mock_indexing_runner.return_value = mock_runner_instance
|
||||
|
||||
# Setup mock feature service
|
||||
mock_runner_instance = mock_indexing_runner.return_value # Setup mock feature service
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = False
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
@ -650,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
|
||||
def test_normal_duplicate_document_indexing_task_with_tenant_queue(
|
||||
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
@ -693,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
|
||||
def test_priority_duplicate_document_indexing_task_with_tenant_queue(
|
||||
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
@ -737,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks:
|
||||
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
|
||||
assert updated_document.indexing_status == "parsing"
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
|
||||
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
|
||||
def test_tenant_queue_wrapper_processes_next_tasks(
|
||||
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
|
||||
):
|
||||
|
||||
@ -18,7 +18,9 @@ class TestEnableSegmentsToIndexTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.enable_segments_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
|
||||
patch(
|
||||
"tasks.enable_segments_to_index_task.IndexProcessorFactory", autospec=True
|
||||
) as mock_index_processor_factory,
|
||||
):
|
||||
# Setup mock index processor
|
||||
mock_processor = MagicMock()
|
||||
@ -370,7 +372,7 @@ class TestEnableSegmentsToIndexTask:
|
||||
redis_client.set(indexing_cache_key, "processing", ex=300)
|
||||
|
||||
# Mock the get_child_chunks method for each segment
|
||||
with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
|
||||
with patch.object(DocumentSegment, "get_child_chunks", autospec=True) as mock_get_child_chunks:
|
||||
# Setup mock to return child chunks for each segment
|
||||
mock_child_chunks = []
|
||||
for i in range(2): # Each segment has 2 child chunks
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -16,16 +16,14 @@ class TestMailAccountDeletionTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_account_deletion_task.mail") as mock_mail,
|
||||
patch("tasks.mail_account_deletion_task.get_email_i18n_service") as mock_get_email_service,
|
||||
patch("tasks.mail_account_deletion_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_account_deletion_task.get_email_i18n_service", autospec=True) as mock_get_email_service,
|
||||
):
|
||||
# Setup mock mail service
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
# Setup mock email service
|
||||
mock_email_service = MagicMock()
|
||||
mock_get_email_service.return_value = mock_email_service
|
||||
|
||||
mock_email_service = mock_get_email_service.return_value
|
||||
yield {
|
||||
"mail": mock_mail,
|
||||
"get_email_service": mock_get_email_service,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -15,16 +15,14 @@ class TestMailChangeMailTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_change_mail_task.mail") as mock_mail,
|
||||
patch("tasks.mail_change_mail_task.get_email_i18n_service") as mock_get_email_i18n_service,
|
||||
patch("tasks.mail_change_mail_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_change_mail_task.get_email_i18n_service", autospec=True) as mock_get_email_i18n_service,
|
||||
):
|
||||
# Setup mock mail service
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
# Setup mock email i18n service
|
||||
mock_email_service = MagicMock()
|
||||
mock_get_email_i18n_service.return_value = mock_email_service
|
||||
|
||||
mock_email_service = mock_get_email_i18n_service.return_value
|
||||
yield {
|
||||
"mail": mock_mail,
|
||||
"email_i18n_service": mock_email_service,
|
||||
|
||||
@ -53,8 +53,8 @@ class TestSendEmailCodeLoginMailTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_email_code_login.mail") as mock_mail,
|
||||
patch("tasks.mail_email_code_login.get_email_i18n_service") as mock_email_service,
|
||||
patch("tasks.mail_email_code_login.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_email_code_login.get_email_i18n_service", autospec=True) as mock_email_service,
|
||||
):
|
||||
# Setup default mock returns
|
||||
mock_mail.is_inited.return_value = True
|
||||
@ -573,7 +573,7 @@ class TestSendEmailCodeLoginMailTask:
|
||||
mock_email_service_instance.send_email.side_effect = exception
|
||||
|
||||
# Mock logging to capture error messages
|
||||
with patch("tasks.mail_email_code_login.logger") as mock_logger:
|
||||
with patch("tasks.mail_email_code_login.logger", autospec=True) as mock_logger:
|
||||
# Act: Execute the task - it should handle the exception gracefully
|
||||
send_email_code_login_mail_task(
|
||||
language=test_language,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -13,18 +13,15 @@ class TestMailInnerTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_inner_task.mail") as mock_mail,
|
||||
patch("tasks.mail_inner_task.get_email_i18n_service") as mock_get_email_i18n_service,
|
||||
patch("tasks.mail_inner_task._render_template_with_strategy") as mock_render_template,
|
||||
patch("tasks.mail_inner_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_inner_task.get_email_i18n_service", autospec=True) as mock_get_email_i18n_service,
|
||||
patch("tasks.mail_inner_task._render_template_with_strategy", autospec=True) as mock_render_template,
|
||||
):
|
||||
# Setup mock mail service
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
# Setup mock email i18n service
|
||||
mock_email_service = MagicMock()
|
||||
mock_get_email_i18n_service.return_value = mock_email_service
|
||||
|
||||
# Setup mock template rendering
|
||||
mock_email_service = mock_get_email_i18n_service.return_value # Setup mock template rendering
|
||||
mock_render_template.return_value = "<html>Test email content</html>"
|
||||
|
||||
yield {
|
||||
|
||||
@ -56,9 +56,9 @@ class TestMailInviteMemberTask:
|
||||
def mock_external_service_dependencies(self):
|
||||
"""Mock setup for external service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_invite_member_task.mail") as mock_mail,
|
||||
patch("tasks.mail_invite_member_task.get_email_i18n_service") as mock_email_service,
|
||||
patch("tasks.mail_invite_member_task.dify_config") as mock_config,
|
||||
patch("tasks.mail_invite_member_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_invite_member_task.get_email_i18n_service", autospec=True) as mock_email_service,
|
||||
patch("tasks.mail_invite_member_task.dify_config", autospec=True) as mock_config,
|
||||
):
|
||||
# Setup mail service mock
|
||||
mock_mail.is_inited.return_value = True
|
||||
@ -306,7 +306,7 @@ class TestMailInviteMemberTask:
|
||||
mock_email_service.send_email.side_effect = Exception("Email service failed")
|
||||
|
||||
# Act & Assert: Execute task and verify exception is handled
|
||||
with patch("tasks.mail_invite_member_task.logger") as mock_logger:
|
||||
with patch("tasks.mail_invite_member_task.logger", autospec=True) as mock_logger:
|
||||
send_invite_member_mail_task(
|
||||
language="en-US",
|
||||
to="test@example.com",
|
||||
|
||||
@ -7,7 +7,7 @@ testing with actual database and service dependencies.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -30,16 +30,14 @@ class TestMailOwnerTransferTask:
|
||||
def mock_mail_dependencies(self):
|
||||
"""Mock setup for mail service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_owner_transfer_task.mail") as mock_mail,
|
||||
patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service,
|
||||
patch("tasks.mail_owner_transfer_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_owner_transfer_task.get_email_i18n_service", autospec=True) as mock_get_email_service,
|
||||
):
|
||||
# Setup mock mail service
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
# Setup mock email service
|
||||
mock_email_service = MagicMock()
|
||||
mock_get_email_service.return_value = mock_email_service
|
||||
|
||||
mock_email_service = mock_get_email_service.return_value
|
||||
yield {
|
||||
"mail": mock_mail,
|
||||
"email_service": mock_email_service,
|
||||
|
||||
@ -5,7 +5,7 @@ This module provides integration tests for email registration tasks
|
||||
using TestContainers to ensure real database and service interactions.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@ -21,16 +21,14 @@ class TestMailRegisterTask:
|
||||
def mock_mail_dependencies(self):
|
||||
"""Mock setup for mail service dependencies."""
|
||||
with (
|
||||
patch("tasks.mail_register_task.mail") as mock_mail,
|
||||
patch("tasks.mail_register_task.get_email_i18n_service") as mock_get_email_service,
|
||||
patch("tasks.mail_register_task.mail", autospec=True) as mock_mail,
|
||||
patch("tasks.mail_register_task.get_email_i18n_service", autospec=True) as mock_get_email_service,
|
||||
):
|
||||
# Setup mock mail service
|
||||
mock_mail.is_inited.return_value = True
|
||||
|
||||
# Setup mock email i18n service
|
||||
mock_email_service = MagicMock()
|
||||
mock_get_email_service.return_value = mock_email_service
|
||||
|
||||
mock_email_service = mock_get_email_service.return_value
|
||||
yield {
|
||||
"mail": mock_mail,
|
||||
"email_service": mock_email_service,
|
||||
@ -76,7 +74,7 @@ class TestMailRegisterTask:
|
||||
to_email = fake.email()
|
||||
code = fake.numerify("######")
|
||||
|
||||
with patch("tasks.mail_register_task.logger") as mock_logger:
|
||||
with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger:
|
||||
send_email_register_mail_task(language="en-US", to=to_email, code=code)
|
||||
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
|
||||
|
||||
@ -89,7 +87,7 @@ class TestMailRegisterTask:
|
||||
to_email = fake.email()
|
||||
account_name = fake.name()
|
||||
|
||||
with patch("tasks.mail_register_task.dify_config") as mock_config:
|
||||
with patch("tasks.mail_register_task.dify_config", autospec=True) as mock_config:
|
||||
mock_config.CONSOLE_WEB_URL = "https://console.dify.ai"
|
||||
|
||||
send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name)
|
||||
@ -129,6 +127,6 @@ class TestMailRegisterTask:
|
||||
to_email = fake.email()
|
||||
account_name = fake.name()
|
||||
|
||||
with patch("tasks.mail_register_task.logger") as mock_logger:
|
||||
with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger:
|
||||
send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name)
|
||||
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
|
||||
|
||||
@ -12,9 +12,17 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
|
||||
conversation.id = "conversation-id"
|
||||
|
||||
with (
|
||||
patch("controllers.console.app.conversation.current_account_with_tenant", return_value=(account, None)),
|
||||
patch("controllers.console.app.conversation.naive_utc_now", return_value=datetime(2026, 2, 9, 0, 0, 0)),
|
||||
patch("controllers.console.app.conversation.db.session") as mock_session,
|
||||
patch(
|
||||
"controllers.console.app.conversation.current_account_with_tenant",
|
||||
return_value=(account, None),
|
||||
autospec=True,
|
||||
),
|
||||
patch(
|
||||
"controllers.console.app.conversation.naive_utc_now",
|
||||
return_value=datetime(2026, 2, 9, 0, 0, 0),
|
||||
autospec=True,
|
||||
),
|
||||
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
|
||||
):
|
||||
mock_session.query.return_value.where.return_value.first.return_value = conversation
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ class TestWorkflowDraftVariableFields:
|
||||
mock_variable.variable_file = mock_variable_file
|
||||
|
||||
# Mock the file helpers
|
||||
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
|
||||
with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers:
|
||||
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
|
||||
|
||||
# Call the function
|
||||
@ -203,7 +203,7 @@ class TestWorkflowDraftVariableFields:
|
||||
}
|
||||
)
|
||||
|
||||
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
|
||||
with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers:
|
||||
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
|
||||
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||
expected_with_value = expected_without_value.copy()
|
||||
|
||||
@ -47,8 +47,8 @@ class TestRefreshTokenApi:
|
||||
token_pair.csrf_token = "new_csrf_token"
|
||||
return token_pair
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
|
||||
"""
|
||||
Test successful token refresh flow.
|
||||
@ -73,7 +73,7 @@ class TestRefreshTokenApi:
|
||||
mock_refresh_token.assert_called_once_with("valid_refresh_token")
|
||||
assert response.json["result"] == "success"
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
def test_refresh_fails_without_token(self, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh failure when no refresh token provided.
|
||||
@ -96,8 +96,8 @@ class TestRefreshTokenApi:
|
||||
assert response["result"] == "fail"
|
||||
assert "No refresh token provided" in response["message"]
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh failure with invalid refresh token.
|
||||
@ -121,8 +121,8 @@ class TestRefreshTokenApi:
|
||||
assert response["result"] == "fail"
|
||||
assert "Invalid refresh token" in response["message"]
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh failure with expired refresh token.
|
||||
@ -146,8 +146,8 @@ class TestRefreshTokenApi:
|
||||
assert response["result"] == "fail"
|
||||
assert "expired" in response["message"].lower()
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
|
||||
"""
|
||||
Test token refresh with empty string token.
|
||||
@ -168,8 +168,8 @@ class TestRefreshTokenApi:
|
||||
assert status_code == 401
|
||||
assert response["result"] == "fail"
|
||||
|
||||
@patch("controllers.console.auth.login.extract_refresh_token")
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token")
|
||||
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
|
||||
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
|
||||
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
|
||||
"""
|
||||
Test that token refresh updates all three tokens.
|
||||
|
||||
@ -39,10 +39,12 @@ def client():
|
||||
|
||||
|
||||
@patch(
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
|
||||
"controllers.console.workspace.tool_providers.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
autospec=True,
|
||||
)
|
||||
@patch("controllers.console.workspace.tool_providers.Session")
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
|
||||
@patch("controllers.console.workspace.tool_providers.Session", autospec=True)
|
||||
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
|
||||
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
|
||||
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
|
||||
# Arrange: reconnect returns tools immediately
|
||||
@ -62,7 +64,7 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
|
||||
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
# Patch MCPToolManageService constructed inside controller
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
|
||||
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True):
|
||||
payload = {
|
||||
"server_url": "http://example.com/mcp",
|
||||
"name": "demo",
|
||||
@ -77,12 +79,19 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
|
||||
# Act
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
|
||||
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
|
||||
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
|
||||
patch(
|
||||
"controllers.console.wraps.current_account_with_tenant",
|
||||
return_value=(MagicMock(id="u1"), "t1"),
|
||||
autospec=True,
|
||||
),
|
||||
patch("libs.login.check_csrf_token", return_value=None, autospec=True), # bypass CSRF in login_required
|
||||
patch(
|
||||
"libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True), autospec=True
|
||||
), # login
|
||||
patch(
|
||||
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
|
||||
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
|
||||
autospec=True,
|
||||
),
|
||||
):
|
||||
resp = client.post(
|
||||
|
||||
@ -77,7 +77,7 @@ class DummyResult:
|
||||
|
||||
|
||||
class TestMCPAppApi:
|
||||
@patch.object(module, "handle_mcp_request", return_value=DummyResult())
|
||||
@patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True)
|
||||
def test_success_request(self, mock_handle):
|
||||
fake_payload(
|
||||
{
|
||||
@ -321,7 +321,7 @@ class TestMCPAppApi:
|
||||
post_fn("server-1")
|
||||
assert "App is unavailable" in str(exc_info.value)
|
||||
|
||||
@patch.object(module, "handle_mcp_request", return_value=None)
|
||||
@patch.object(module, "handle_mcp_request", return_value=None, autospec=True)
|
||||
def test_mcp_request_no_response(self, mock_handle):
|
||||
"""Test when handle_mcp_request returns None"""
|
||||
fake_payload(
|
||||
@ -380,7 +380,7 @@ class TestMCPAppApi:
|
||||
api = module.MCPAppApi()
|
||||
api._get_mcp_server_and_app = MagicMock(return_value=(server, app))
|
||||
|
||||
with patch.object(module, "handle_mcp_request", return_value=DummyResult()):
|
||||
with patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True):
|
||||
post_fn = unwrap(api.post)
|
||||
response = post_fn("server-1")
|
||||
assert isinstance(response, Response)
|
||||
@ -409,7 +409,7 @@ class TestMCPAppApi:
|
||||
api = module.MCPAppApi()
|
||||
api._get_mcp_server_and_app = MagicMock(return_value=(server, app))
|
||||
|
||||
with patch.object(module, "handle_mcp_request", return_value=DummyResult()):
|
||||
with patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True):
|
||||
post_fn = unwrap(api.post)
|
||||
response = post_fn("server-1")
|
||||
assert isinstance(response, Response)
|
||||
|
||||
@ -12,7 +12,7 @@ from controllers.service_api.index import IndexApi
|
||||
class TestIndexApi:
|
||||
"""Test suite for IndexApi resource."""
|
||||
|
||||
@patch("controllers.service_api.index.dify_config")
|
||||
@patch("controllers.service_api.index.dify_config", autospec=True)
|
||||
def test_get_returns_api_info(self, mock_config, app):
|
||||
"""Test that GET returns API metadata with correct structure."""
|
||||
# Arrange
|
||||
|
||||
@ -71,17 +71,17 @@ class TestBaseAppRunnerMultimodal:
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
@ -158,17 +158,17 @@ class TestBaseAppRunnerMultimodal:
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
@ -231,17 +231,17 @@ class TestBaseAppRunnerMultimodal:
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
@ -282,9 +282,9 @@ class TestBaseAppRunnerMultimodal:
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
@ -321,14 +321,14 @@ class TestBaseAppRunnerMultimodal:
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
# Setup mock to raise exception
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
# Act
|
||||
# Create a mock runner with the method bound
|
||||
runner = MagicMock()
|
||||
@ -368,17 +368,17 @@ class TestBaseAppRunnerMultimodal:
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
@ -420,17 +420,17 @@ class TestBaseAppRunnerMultimodal:
|
||||
)
|
||||
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
|
||||
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
|
||||
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
|
||||
# Setup mock tool file manager
|
||||
mock_mgr = MagicMock()
|
||||
mock_mgr.create_file_by_url.return_value = mock_tool_file
|
||||
mock_mgr_class.return_value = mock_mgr
|
||||
|
||||
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
|
||||
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
|
||||
# Setup mock message file
|
||||
mock_msg_file_class.return_value = mock_message_file
|
||||
|
||||
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
|
||||
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
|
||||
mock_session.add = MagicMock()
|
||||
mock_session.commit = MagicMock()
|
||||
mock_session.refresh = MagicMock()
|
||||
|
||||
@ -84,7 +84,7 @@ def mock_time():
|
||||
mock_time_val += seconds
|
||||
return mock_time_val
|
||||
|
||||
with patch("time.time", return_value=mock_time_val) as mock:
|
||||
with patch("time.time", return_value=mock_time_val, autospec=True) as mock:
|
||||
mock.increment = increment_time
|
||||
yield mock
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from core.helper.ssrf_proxy import (
|
||||
)
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_successful_request(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
@ -22,7 +22,7 @@ def test_successful_request(mock_get_client):
|
||||
mock_client.request.assert_called_once()
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_retry_exceed_max_retries(mock_get_client):
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
@ -71,7 +71,7 @@ class TestGetUserProvidedHostHeader:
|
||||
assert result in ("first.com", "second.com")
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
"""Test that user-provided Host header is preserved in the request."""
|
||||
mock_client = MagicMock()
|
||||
@ -89,7 +89,7 @@ def test_host_header_preservation_with_user_header(mock_get_client):
|
||||
assert call_kwargs["headers"]["host"] == custom_host
|
||||
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
|
||||
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
|
||||
"""Test that Host header is preserved regardless of case."""
|
||||
@ -113,7 +113,7 @@ class TestFollowRedirectsParameter:
|
||||
These tests verify that follow_redirects is correctly passed to client.request().
|
||||
"""
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_follow_redirects_passed_to_request(self, mock_get_client):
|
||||
"""Verify follow_redirects IS passed to client.request()."""
|
||||
mock_client = MagicMock()
|
||||
@ -128,7 +128,7 @@ class TestFollowRedirectsParameter:
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
|
||||
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
|
||||
mock_client = MagicMock()
|
||||
@ -145,7 +145,7 @@ class TestFollowRedirectsParameter:
|
||||
assert call_kwargs.get("follow_redirects") is True
|
||||
assert "allow_redirects" not in call_kwargs
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
|
||||
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
|
||||
mock_client = MagicMock()
|
||||
@ -160,7 +160,7 @@ class TestFollowRedirectsParameter:
|
||||
call_kwargs = mock_client.request.call_args.kwargs
|
||||
assert "follow_redirects" not in call_kwargs
|
||||
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client")
|
||||
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
|
||||
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
|
||||
"""Verify follow_redirects takes precedence when both are specified."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
@ -72,7 +72,7 @@ class TestTraceContextFilter:
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
|
||||
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
):
|
||||
@ -108,7 +108,9 @@ class TestIdentityContextFilter:
|
||||
filter = IdentityContextFilter()
|
||||
|
||||
# Should not raise even if something goes wrong
|
||||
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
|
||||
with mock.patch(
|
||||
"core.logging.filters.flask.has_request_context", side_effect=Exception("Test error"), autospec=True
|
||||
):
|
||||
result = filter.filter(log_record)
|
||||
assert result is True
|
||||
assert log_record.tenant_id == ""
|
||||
|
||||
@ -8,7 +8,7 @@ class TestGetSpanIdFromOtelContext:
|
||||
def test_returns_none_without_span(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result is None
|
||||
|
||||
@ -20,7 +20,7 @@ class TestGetSpanIdFromOtelContext:
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True):
|
||||
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result == "051581bf3bb55c45"
|
||||
@ -28,7 +28,7 @@ class TestGetSpanIdFromOtelContext:
|
||||
def test_returns_none_on_exception(self):
|
||||
from core.helper.trace_id_helper import get_span_id_from_otel_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
|
||||
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error"), autospec=True):
|
||||
result = get_span_id_from_otel_context()
|
||||
assert result is None
|
||||
|
||||
@ -37,7 +37,7 @@ class TestGenerateTraceparentHeader:
|
||||
def test_generates_valid_format(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
assert result is not None
|
||||
@ -58,7 +58,7 @@ class TestGenerateTraceparentHeader:
|
||||
mock_context.span_id = 0x051581BF3BB55C45
|
||||
mock_span.get_span_context.return_value = mock_context
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True):
|
||||
with (
|
||||
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
|
||||
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
|
||||
@ -70,7 +70,7 @@ class TestGenerateTraceparentHeader:
|
||||
def test_generates_hex_only_values(self):
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
|
||||
with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True):
|
||||
result = generate_traceparent_header()
|
||||
|
||||
parts = result.split("-")
|
||||
|
||||
@ -32,7 +32,7 @@ class TestConstants:
|
||||
class TestCreateSSRFProxyMCPHTTPClient:
|
||||
"""Test create_ssrf_proxy_mcp_http_client function."""
|
||||
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
def test_create_client_with_all_url_proxy(self, mock_config):
|
||||
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
|
||||
@ -50,7 +50,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
# Clean up
|
||||
client.close()
|
||||
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
def test_create_client_with_http_https_proxies(self, mock_config):
|
||||
"""Test client creation with separate HTTP/HTTPS proxies."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = None
|
||||
@ -66,7 +66,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
# Clean up
|
||||
client.close()
|
||||
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
def test_create_client_without_proxy(self, mock_config):
|
||||
"""Test client creation without proxy configuration."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = None
|
||||
@ -88,7 +88,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
# Clean up
|
||||
client.close()
|
||||
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
def test_create_client_default_params(self, mock_config):
|
||||
"""Test client creation with default parameters."""
|
||||
mock_config.SSRF_PROXY_ALL_URL = None
|
||||
@ -111,8 +111,8 @@ class TestCreateSSRFProxyMCPHTTPClient:
|
||||
class TestSSRFProxySSEConnect:
|
||||
"""Test ssrf_proxy_sse_connect function."""
|
||||
|
||||
@patch("core.mcp.utils.connect_sse")
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||
@patch("core.mcp.utils.connect_sse", autospec=True)
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
|
||||
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
|
||||
"""Test SSE connection with pre-configured client."""
|
||||
# Setup mocks
|
||||
@ -138,9 +138,9 @@ class TestSSRFProxySSEConnect:
|
||||
# Verify result
|
||||
assert result == mock_context
|
||||
|
||||
@patch("core.mcp.utils.connect_sse")
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||
@patch("core.mcp.utils.dify_config")
|
||||
@patch("core.mcp.utils.connect_sse", autospec=True)
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
|
||||
@patch("core.mcp.utils.dify_config", autospec=True)
|
||||
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
|
||||
"""Test SSE connection without pre-configured client."""
|
||||
# Setup config
|
||||
@ -183,8 +183,8 @@ class TestSSRFProxySSEConnect:
|
||||
# Verify result
|
||||
assert result == mock_context
|
||||
|
||||
@patch("core.mcp.utils.connect_sse")
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||
@patch("core.mcp.utils.connect_sse", autospec=True)
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
|
||||
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
|
||||
"""Test SSE connection with custom timeout."""
|
||||
# Setup mocks
|
||||
@ -209,8 +209,8 @@ class TestSSRFProxySSEConnect:
|
||||
# Verify result
|
||||
assert result == mock_context
|
||||
|
||||
@patch("core.mcp.utils.connect_sse")
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
|
||||
@patch("core.mcp.utils.connect_sse", autospec=True)
|
||||
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
|
||||
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
|
||||
"""Test SSE connection cleans up client on error."""
|
||||
# Setup mocks
|
||||
@ -227,7 +227,7 @@ class TestSSRFProxySSEConnect:
|
||||
# Verify client was cleaned up
|
||||
mock_client.close.assert_called_once()
|
||||
|
||||
@patch("core.mcp.utils.connect_sse")
|
||||
@patch("core.mcp.utils.connect_sse", autospec=True)
|
||||
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
|
||||
"""Test SSE connection doesn't clean up provided client on error."""
|
||||
# Setup mocks
|
||||
|
||||
@ -324,7 +324,7 @@ class TestOpenAIModeration:
|
||||
with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"):
|
||||
OpenAIModeration.validate_config("test-tenant", config)
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
|
||||
"""Test input moderation when OpenAI API returns no violations."""
|
||||
# Mock the model manager and instance
|
||||
@ -341,7 +341,7 @@ class TestOpenAIModeration:
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == "Content flagged by OpenAI moderation."
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
|
||||
"""Test input moderation when OpenAI API detects violations."""
|
||||
# Mock the model manager to return violation
|
||||
@ -358,7 +358,7 @@ class TestOpenAIModeration:
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == "Content flagged by OpenAI moderation."
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
|
||||
"""Test that query is included in moderation check with special key."""
|
||||
mock_instance = MagicMock()
|
||||
@ -385,7 +385,7 @@ class TestOpenAIModeration:
|
||||
assert "u" in moderated_text
|
||||
assert "e" in moderated_text
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock):
|
||||
"""Test input moderation when inputs_config is disabled."""
|
||||
config = {
|
||||
@ -400,7 +400,7 @@ class TestOpenAIModeration:
|
||||
# Should not call the API when disabled
|
||||
mock_model_manager.assert_not_called()
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
|
||||
"""Test output moderation when OpenAI API returns no violations."""
|
||||
mock_instance = MagicMock()
|
||||
@ -414,7 +414,7 @@ class TestOpenAIModeration:
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
assert result.preset_response == "Response blocked by moderation."
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
|
||||
"""Test output moderation when OpenAI API detects violations."""
|
||||
mock_instance = MagicMock()
|
||||
@ -427,7 +427,7 @@ class TestOpenAIModeration:
|
||||
assert result.flagged is True
|
||||
assert result.action == ModerationAction.DIRECT_OUTPUT
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock):
|
||||
"""Test output moderation when outputs_config is disabled."""
|
||||
config = {
|
||||
@ -441,7 +441,7 @@ class TestOpenAIModeration:
|
||||
assert result.flagged is False
|
||||
mock_model_manager.assert_not_called()
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_model_manager_called_with_correct_params(
|
||||
self, mock_model_manager: Mock, openai_moderation: OpenAIModeration
|
||||
):
|
||||
@ -494,7 +494,7 @@ class TestModerationRuleStructure:
|
||||
class TestModerationFactoryIntegration:
|
||||
"""Test suite for ModerationFactory integration."""
|
||||
|
||||
@patch("core.moderation.factory.code_based_extension")
|
||||
@patch("core.moderation.factory.code_based_extension", autospec=True)
|
||||
def test_factory_delegates_to_extension(self, mock_extension: Mock):
|
||||
"""Test ModerationFactory delegates to extension system."""
|
||||
from core.moderation.factory import ModerationFactory
|
||||
@ -518,7 +518,7 @@ class TestModerationFactoryIntegration:
|
||||
assert result.flagged is False
|
||||
mock_instance.moderation_for_inputs.assert_called_once()
|
||||
|
||||
@patch("core.moderation.factory.code_based_extension")
|
||||
@patch("core.moderation.factory.code_based_extension", autospec=True)
|
||||
def test_factory_validate_config_delegates(self, mock_extension: Mock):
|
||||
"""Test ModerationFactory.validate_config delegates to extension."""
|
||||
from core.moderation.factory import ModerationFactory
|
||||
@ -629,7 +629,7 @@ class TestPresetManagement:
|
||||
assert result.flagged is True
|
||||
assert result.preset_response == "Custom output blocked message"
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock):
|
||||
"""Test preset response is properly returned for OpenAI input violations."""
|
||||
mock_instance = MagicMock()
|
||||
@ -650,7 +650,7 @@ class TestPresetManagement:
|
||||
assert result.flagged is True
|
||||
assert result.preset_response == "OpenAI input blocked"
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock):
|
||||
"""Test preset response is properly returned for OpenAI output violations."""
|
||||
mock_instance = MagicMock()
|
||||
@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced:
|
||||
- Performance considerations
|
||||
"""
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_api_timeout_handling(self, mock_model_manager: Mock):
|
||||
"""
|
||||
Test graceful handling of OpenAI API timeouts.
|
||||
@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced:
|
||||
with pytest.raises(TimeoutError):
|
||||
moderation.moderation_for_inputs({"text": "test"}, "")
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock):
|
||||
"""
|
||||
Test handling of OpenAI API rate limit errors.
|
||||
@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced:
|
||||
with pytest.raises(Exception, match="Rate limit exceeded"):
|
||||
moderation.moderation_for_inputs({"text": "test"}, "")
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock):
|
||||
"""
|
||||
Test OpenAI moderation with multiple input fields.
|
||||
@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced:
|
||||
assert "u" in moderated_text
|
||||
assert "e" in moderated_text
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_empty_text_handling(self, mock_model_manager: Mock):
|
||||
"""
|
||||
Test OpenAI moderation with empty text inputs.
|
||||
@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced:
|
||||
assert result.flagged is False
|
||||
mock_instance.invoke_moderation.assert_called_once()
|
||||
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
|
||||
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
|
||||
def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock):
|
||||
"""
|
||||
Test that ModelManager fetches a fresh model instance on each call.
|
||||
|
||||
@ -64,7 +64,7 @@ class TestPluginEndpointClientDelete:
|
||||
"data": True,
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = endpoint_client.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
@ -102,7 +102,7 @@ class TestPluginEndpointClientDelete:
|
||||
),
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = endpoint_client.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
@ -139,7 +139,7 @@ class TestPluginEndpointClientDelete:
|
||||
),
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
|
||||
endpoint_client.delete_endpoint(
|
||||
@ -174,7 +174,7 @@ class TestPluginEndpointClientDelete:
|
||||
"message": '{"error_type": "PluginDaemonInternalServerError", "message": "Record Not Found"}',
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = endpoint_client.delete_endpoint(
|
||||
tenant_id=tenant_id,
|
||||
@ -222,7 +222,7 @@ class TestPluginEndpointClientDelete:
|
||||
),
|
||||
}
|
||||
|
||||
with patch("httpx.request") as mock_request:
|
||||
with patch("httpx.request", autospec=True) as mock_request:
|
||||
# Act - first call
|
||||
mock_request.return_value = mock_response_success
|
||||
result1 = endpoint_client.delete_endpoint(
|
||||
@ -266,7 +266,7 @@ class TestPluginEndpointClientDelete:
|
||||
"message": '{"error_type": "PluginDaemonUnauthorizedError", "message": "unauthorized access"}',
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
endpoint_client.delete_endpoint(
|
||||
|
||||
@ -114,7 +114,7 @@ class TestPluginRuntimeExecution:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"result": "success"}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
response = plugin_client._request("GET", "plugin/test-tenant/management/list")
|
||||
|
||||
@ -132,7 +132,7 @@ class TestPluginRuntimeExecution:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
|
||||
@ -143,7 +143,7 @@ class TestPluginRuntimeExecution:
|
||||
def test_request_connection_error(self, plugin_client, mock_config):
|
||||
"""Test handling of connection errors during request."""
|
||||
# Arrange
|
||||
with patch("httpx.request", side_effect=httpx.RequestError("Connection failed")):
|
||||
with patch("httpx.request", side_effect=httpx.RequestError("Connection failed"), autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInnerError) as exc_info:
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
@ -182,7 +182,7 @@ class TestPluginRuntimeSandboxIsolation:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
|
||||
@ -201,7 +201,7 @@ class TestPluginRuntimeSandboxIsolation:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": {"result": "isolated_execution"}}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = plugin_client._request_with_plugin_daemon_response(
|
||||
"POST", "plugin/test-tenant/dispatch/tool/invoke", TestResponse, data={"tool": "test"}
|
||||
@ -218,7 +218,7 @@ class TestPluginRuntimeSandboxIsolation:
|
||||
error_message = json.dumps({"error_type": "PluginDaemonUnauthorizedError", "message": "Unauthorized access"})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonUnauthorizedError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -234,7 +234,7 @@ class TestPluginRuntimeSandboxIsolation:
|
||||
)
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginPermissionDeniedError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
|
||||
@ -272,7 +272,7 @@ class TestPluginRuntimeResourceLimits:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
|
||||
@ -283,7 +283,7 @@ class TestPluginRuntimeResourceLimits:
|
||||
def test_timeout_error_handling(self, plugin_client, mock_config):
|
||||
"""Test handling of timeout errors."""
|
||||
# Arrange
|
||||
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timeout")):
|
||||
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timeout"), autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInnerError) as exc_info:
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
@ -292,7 +292,7 @@ class TestPluginRuntimeResourceLimits:
|
||||
def test_streaming_request_timeout(self, plugin_client, mock_config):
|
||||
"""Test timeout handling for streaming requests."""
|
||||
# Arrange
|
||||
with patch("httpx.stream", side_effect=httpx.TimeoutException("Stream timeout")):
|
||||
with patch("httpx.stream", side_effect=httpx.TimeoutException("Stream timeout"), autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInnerError) as exc_info:
|
||||
list(plugin_client._stream_request("POST", "plugin/test-tenant/stream"))
|
||||
@ -308,7 +308,7 @@ class TestPluginRuntimeResourceLimits:
|
||||
)
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
|
||||
@ -352,7 +352,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(InvokeRateLimitError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
|
||||
@ -371,7 +371,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(InvokeAuthorizationError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
|
||||
@ -390,7 +390,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(InvokeBadRequestError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
|
||||
@ -409,7 +409,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(InvokeConnectionError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
|
||||
@ -428,7 +428,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(InvokeServerUnavailableError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
|
||||
@ -446,7 +446,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(CredentialsValidateFailedError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/validate", bool)
|
||||
@ -462,7 +462,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
)
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginNotFoundError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/get", bool)
|
||||
@ -478,7 +478,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
)
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginUniqueIdentifierError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/install", bool)
|
||||
@ -494,7 +494,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
)
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonBadRequestError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
|
||||
@ -508,7 +508,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginDaemonNotFoundError", "message": "Resource not found"})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonNotFoundError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/resource", bool)
|
||||
@ -526,7 +526,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "PluginInvokeError", "message": invoke_error_message})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginInvokeError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
|
||||
@ -540,7 +540,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
error_message = json.dumps({"error_type": "UnknownErrorType", "message": "Unknown error occurred"})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
|
||||
@ -555,7 +555,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
"Server Error", request=MagicMock(), response=mock_response
|
||||
)
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -567,7 +567,7 @@ class TestPluginRuntimeErrorHandling:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -610,7 +610,7 @@ class TestPluginRuntimeCommunication:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": {"value": "test", "count": 42}}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = plugin_client._request_with_plugin_daemon_response(
|
||||
"POST", "plugin/test-tenant/test", TestModel, data={"input": "data"}
|
||||
@ -637,7 +637,7 @@ class TestPluginRuntimeCommunication:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -667,7 +667,7 @@ class TestPluginRuntimeCommunication:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -689,7 +689,7 @@ class TestPluginRuntimeCommunication:
|
||||
def test_streaming_connection_error(self, plugin_client, mock_config):
|
||||
"""Test connection error during streaming."""
|
||||
# Arrange
|
||||
with patch("httpx.stream", side_effect=httpx.RequestError("Stream connection failed")):
|
||||
with patch("httpx.stream", side_effect=httpx.RequestError("Stream connection failed"), autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInnerError) as exc_info:
|
||||
list(plugin_client._stream_request("POST", "plugin/test-tenant/stream"))
|
||||
@ -707,7 +707,7 @@ class TestPluginRuntimeCommunication:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"status": "success", "data": {"key": "value"}}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = plugin_client._request_with_model("GET", "plugin/test-tenant/direct", DirectModel)
|
||||
|
||||
@ -732,7 +732,7 @@ class TestPluginRuntimeCommunication:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -764,7 +764,7 @@ class TestPluginRuntimeCommunication:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -814,7 +814,7 @@ class TestPluginToolManagerIntegration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -844,7 +844,7 @@ class TestPluginToolManagerIntegration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -868,7 +868,7 @@ class TestPluginToolManagerIntegration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -892,7 +892,7 @@ class TestPluginToolManagerIntegration:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -945,7 +945,7 @@ class TestPluginInstallerIntegration:
|
||||
},
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.list_plugins("test-tenant")
|
||||
|
||||
@ -959,7 +959,7 @@ class TestPluginInstallerIntegration:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.uninstall("test-tenant", "plugin-installation-id")
|
||||
|
||||
@ -973,7 +973,7 @@ class TestPluginInstallerIntegration:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.fetch_plugin_by_identifier("test-tenant", "plugin-identifier")
|
||||
|
||||
@ -1012,7 +1012,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -1025,7 +1025,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
# Missing required fields in response
|
||||
mock_response.json.return_value = {"invalid": "structure"}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -1041,7 +1041,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1065,7 +1065,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("POST", "plugin/test-tenant/upload", data=b"binary data")
|
||||
|
||||
@ -1081,7 +1081,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
|
||||
files = {"file": ("test.txt", b"file content", "text/plain")}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("POST", "plugin/test-tenant/upload", files=files)
|
||||
|
||||
@ -1095,7 +1095,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = []
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1115,7 +1115,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act & Assert
|
||||
@ -1136,7 +1136,7 @@ class TestPluginRuntimeEdgeCases:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": -1, "message": "Plain text error message", "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -1174,7 +1174,7 @@ class TestPluginRuntimeAdvancedScenarios:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
for i in range(5):
|
||||
result = plugin_client._request_with_plugin_daemon_response("GET", f"plugin/test-tenant/test/{i}", bool)
|
||||
@ -1203,7 +1203,7 @@ class TestPluginRuntimeAdvancedScenarios:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": complex_data}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = plugin_client._request_with_plugin_daemon_response(
|
||||
"POST", "plugin/test-tenant/complex", ComplexModel
|
||||
@ -1231,7 +1231,7 @@ class TestPluginRuntimeAdvancedScenarios:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1262,7 +1262,7 @@ class TestPluginRuntimeAdvancedScenarios:
|
||||
mock_response.status_code = 200
|
||||
return mock_response
|
||||
|
||||
with patch("httpx.request", side_effect=side_effect):
|
||||
with patch("httpx.request", side_effect=side_effect, autospec=True):
|
||||
# Act & Assert - First two calls should fail
|
||||
with pytest.raises(PluginDaemonInnerError):
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
@ -1286,7 +1286,7 @@ class TestPluginRuntimeAdvancedScenarios:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("GET", "plugin/test-tenant/test", headers=custom_headers)
|
||||
|
||||
@ -1312,7 +1312,7 @@ class TestPluginRuntimeAdvancedScenarios:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1359,7 +1359,7 @@ class TestPluginRuntimeSecurityAndValidation:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request("GET", "plugin/test-tenant/test")
|
||||
|
||||
@ -1381,7 +1381,7 @@ class TestPluginRuntimeSecurityAndValidation:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
@ -1403,7 +1403,7 @@ class TestPluginRuntimeSecurityAndValidation:
|
||||
error_message = json.dumps({"error_type": "PluginDaemonUnauthorizedError", "message": "Invalid API key"})
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonUnauthorizedError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
|
||||
@ -1424,7 +1424,7 @@ class TestPluginRuntimeSecurityAndValidation:
|
||||
)
|
||||
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonBadRequestError) as exc_info:
|
||||
plugin_client._request_with_plugin_daemon_response(
|
||||
@ -1438,7 +1438,7 @@ class TestPluginRuntimeSecurityAndValidation:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
|
||||
with patch("httpx.request", return_value=mock_response) as mock_request:
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
|
||||
# Act
|
||||
plugin_client._request(
|
||||
"POST", "plugin/test-tenant/test", headers={"Content-Type": "application/json"}, data={"key": "value"}
|
||||
@ -1489,7 +1489,7 @@ class TestPluginRuntimePerformanceScenarios:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1524,7 +1524,7 @@ class TestPluginRuntimePerformanceScenarios:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act - Process chunks one by one
|
||||
@ -1539,7 +1539,7 @@ class TestPluginRuntimePerformanceScenarios:
|
||||
def test_timeout_with_slow_response(self, plugin_client, mock_config):
|
||||
"""Test timeout handling with slow response simulation."""
|
||||
# Arrange
|
||||
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timed out after 30s")):
|
||||
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timed out after 30s"), autospec=True):
|
||||
# Act & Assert
|
||||
with pytest.raises(PluginDaemonInnerError) as exc_info:
|
||||
plugin_client._request("GET", "plugin/test-tenant/slow-endpoint")
|
||||
@ -1554,7 +1554,7 @@ class TestPluginRuntimePerformanceScenarios:
|
||||
|
||||
request_results = []
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act - Simulate 10 concurrent requests
|
||||
for i in range(10):
|
||||
result = plugin_client._request_with_plugin_daemon_response(
|
||||
@ -1612,7 +1612,7 @@ class TestPluginToolManagerAdvanced:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1641,7 +1641,7 @@ class TestPluginToolManagerAdvanced:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1673,7 +1673,7 @@ class TestPluginToolManagerAdvanced:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1704,7 +1704,7 @@ class TestPluginToolManagerAdvanced:
|
||||
mock_response = MagicMock()
|
||||
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
|
||||
|
||||
with patch("httpx.stream") as mock_stream:
|
||||
with patch("httpx.stream", autospec=True) as mock_stream:
|
||||
mock_stream.return_value.__enter__.return_value = mock_response
|
||||
|
||||
# Act
|
||||
@ -1770,7 +1770,7 @@ class TestPluginInstallerAdvanced:
|
||||
},
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.upload_pkg("test-tenant", plugin_package, verify_signature=False)
|
||||
|
||||
@ -1788,7 +1788,7 @@ class TestPluginInstallerAdvanced:
|
||||
"data": {"content": "# Plugin README\n\nThis is a test plugin.", "language": "en"},
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en")
|
||||
|
||||
@ -1807,7 +1807,7 @@ class TestPluginInstallerAdvanced:
|
||||
|
||||
mock_response.raise_for_status = raise_for_status
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act & Assert - Should raise HTTPStatusError for 404
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en")
|
||||
@ -1826,7 +1826,7 @@ class TestPluginInstallerAdvanced:
|
||||
},
|
||||
}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.list_plugins_with_total("test-tenant", page=2, page_size=20)
|
||||
|
||||
@ -1848,7 +1848,7 @@ class TestPluginInstallerAdvanced:
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"code": 0, "message": "", "data": [True, False]}
|
||||
|
||||
with patch("httpx.request", return_value=mock_response):
|
||||
with patch("httpx.request", return_value=mock_response, autospec=True):
|
||||
# Act
|
||||
result = installer.check_tools_existence("test-tenant", provider_ids)
|
||||
|
||||
|
||||
@ -142,7 +142,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
|
||||
|
||||
prompt_transform = AdvancedPromptTransform()
|
||||
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
|
||||
with patch("core.workflow.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
|
||||
with patch("core.workflow.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string:
|
||||
mock_get_encoded_string.return_value = ImagePromptMessageContent(
|
||||
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
|
||||
)
|
||||
|
||||
@ -83,7 +83,7 @@ def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, exp
|
||||
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
|
||||
|
||||
# We need to handle the import inside _extract_images
|
||||
with patch("pypdfium2.raw") as mock_raw:
|
||||
with patch("pypdfium2.raw", autospec=True) as mock_raw:
|
||||
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
@ -115,7 +115,7 @@ def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_sid
|
||||
|
||||
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
|
||||
|
||||
with patch("pypdfium2.raw") as mock_raw:
|
||||
with patch("pypdfium2.raw", autospec=True) as mock_raw:
|
||||
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
@ -133,11 +133,11 @@ def test_extract_calls_extract_images(mock_dependencies, monkeypatch):
|
||||
mock_text_page.get_text_range.return_value = "Page text content"
|
||||
mock_page.get_textpage.return_value = mock_text_page
|
||||
|
||||
with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc):
|
||||
with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc, autospec=True):
|
||||
# Mock Blob
|
||||
mock_blob = MagicMock()
|
||||
mock_blob.source = "test.pdf"
|
||||
with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob):
|
||||
with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob, autospec=True):
|
||||
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
|
||||
|
||||
# Mock _extract_images to return a known string
|
||||
@ -175,7 +175,7 @@ def test_extract_images_failures(mock_dependencies):
|
||||
|
||||
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
|
||||
|
||||
with patch("pypdfium2.raw") as mock_raw:
|
||||
with patch("pypdfium2.raw", autospec=True) as mock_raw:
|
||||
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
|
||||
result = extractor._extract_images(mock_page)
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class TestRerankModelRunner:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
@ -397,19 +397,19 @@ class TestWeightRerankRunner:
|
||||
@pytest.fixture
|
||||
def mock_model_manager(self):
|
||||
"""Mock ModelManager for embedding model."""
|
||||
with patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager:
|
||||
with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager:
|
||||
yield mock_manager
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cache_embedding(self):
|
||||
"""Mock CacheEmbedding for vector operations."""
|
||||
with patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache:
|
||||
with patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache:
|
||||
yield mock_cache
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jieba_handler(self):
|
||||
"""Mock JiebaKeywordTableHandler for keyword extraction."""
|
||||
with patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba:
|
||||
with patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba:
|
||||
yield mock_jieba
|
||||
|
||||
@pytest.fixture
|
||||
@ -914,7 +914,7 @@ class TestRerankIntegration:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
@ -1026,7 +1026,7 @@ class TestRerankEdgeCases:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
@ -1295,9 +1295,9 @@ class TestRerankEdgeCases:
|
||||
|
||||
# Mock dependencies
|
||||
with (
|
||||
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba,
|
||||
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager,
|
||||
patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache,
|
||||
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
|
||||
patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager,
|
||||
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
|
||||
):
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.extract_keywords.return_value = ["test"]
|
||||
@ -1367,7 +1367,7 @@ class TestRerankPerformance:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
@ -1441,9 +1441,9 @@ class TestRerankPerformance:
|
||||
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights)
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba,
|
||||
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager,
|
||||
patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache,
|
||||
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
|
||||
patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager,
|
||||
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
|
||||
):
|
||||
mock_handler = MagicMock()
|
||||
# Track keyword extraction calls
|
||||
@ -1484,7 +1484,7 @@ class TestRerankErrorHandling:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_model_manager(self):
|
||||
"""Auto-use fixture to patch ModelManager for all tests in this class."""
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
|
||||
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
|
||||
mock_mm.return_value.check_model_support_vision.return_value = False
|
||||
yield mock_mm
|
||||
|
||||
@ -1592,9 +1592,9 @@ class TestRerankErrorHandling:
|
||||
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights)
|
||||
|
||||
with (
|
||||
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba,
|
||||
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager,
|
||||
patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache,
|
||||
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
|
||||
patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager,
|
||||
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
|
||||
):
|
||||
mock_handler = MagicMock()
|
||||
mock_handler.extract_keywords.return_value = ["test"]
|
||||
|
||||
@ -48,7 +48,7 @@ class TestRepositoryFactory:
|
||||
import_string("invalidpath")
|
||||
assert "doesn't look like a module path" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_workflow_execution_repository_success(self, mock_config):
|
||||
"""Test successful WorkflowExecutionRepository creation."""
|
||||
# Setup mock configuration
|
||||
@ -66,7 +66,7 @@ class TestRepositoryFactory:
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock import_string
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
@ -83,7 +83,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_workflow_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
@ -101,7 +101,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
@ -115,7 +115,7 @@ class TestRepositoryFactory:
|
||||
mock_repository_class.side_effect = Exception("Instantiation failed")
|
||||
|
||||
# Mock import_string to return a failing class
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
@ -125,7 +125,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_workflow_node_execution_repository_success(self, mock_config):
|
||||
"""Test successful WorkflowNodeExecutionRepository creation."""
|
||||
# Setup mock configuration
|
||||
@ -143,7 +143,7 @@ class TestRepositoryFactory:
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock import_string
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
@ -160,7 +160,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
@ -178,7 +178,7 @@ class TestRepositoryFactory:
|
||||
)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
@ -192,7 +192,7 @@ class TestRepositoryFactory:
|
||||
mock_repository_class.side_effect = Exception("Instantiation failed")
|
||||
|
||||
# Mock import_string to return a failing class
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
@ -208,7 +208,7 @@ class TestRepositoryFactory:
|
||||
error = RepositoryImportError(error_message)
|
||||
assert str(error) == error_message
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
@patch("core.repositories.factory.dify_config", autospec=True)
|
||||
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
|
||||
"""Test repository creation with Engine instead of sessionmaker."""
|
||||
# Setup mock configuration
|
||||
@ -226,7 +226,7 @@ class TestRepositoryFactory:
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock import_string
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
|
||||
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_engine, # Using Engine instead of sessionmaker
|
||||
user=mock_user,
|
||||
|
||||
@ -196,7 +196,7 @@ class TestSchemaResolver:
|
||||
resolved1 = resolve_dify_schema_refs(schema)
|
||||
|
||||
# Mock the registry to return different data
|
||||
with patch.object(self.registry, "get_schema") as mock_get:
|
||||
with patch.object(self.registry, "get_schema", autospec=True) as mock_get:
|
||||
mock_get.return_value = {"type": "different"}
|
||||
|
||||
# Second resolution should use cache
|
||||
@ -445,7 +445,7 @@ class TestSchemaResolverClass:
|
||||
|
||||
# Second resolver should use the same cache
|
||||
resolver2 = SchemaResolver()
|
||||
with patch.object(resolver2.registry, "get_schema") as mock_get:
|
||||
with patch.object(resolver2.registry, "get_schema", autospec=True) as mock_get:
|
||||
result2 = resolver2.resolve(schema)
|
||||
# Should not call registry since it's in cache
|
||||
mock_get.assert_not_called()
|
||||
|
||||
@ -138,8 +138,8 @@ class TestFlaskExecutionContext:
|
||||
class TestCaptureFlaskContext:
|
||||
"""Test capture_flask_context function."""
|
||||
|
||||
@patch("context.flask_app_context.current_app")
|
||||
@patch("context.flask_app_context.g")
|
||||
@patch("context.flask_app_context.current_app", autospec=True)
|
||||
@patch("context.flask_app_context.g", autospec=True)
|
||||
def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
|
||||
"""Test capture_flask_context captures Flask app."""
|
||||
mock_app = MagicMock()
|
||||
@ -152,8 +152,8 @@ class TestCaptureFlaskContext:
|
||||
|
||||
assert ctx._flask_app == mock_app
|
||||
|
||||
@patch("context.flask_app_context.current_app")
|
||||
@patch("context.flask_app_context.g")
|
||||
@patch("context.flask_app_context.current_app", autospec=True)
|
||||
@patch("context.flask_app_context.g", autospec=True)
|
||||
def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
|
||||
"""Test capture_flask_context captures user from Flask g object."""
|
||||
mock_app = MagicMock()
|
||||
@ -170,7 +170,7 @@ class TestCaptureFlaskContext:
|
||||
|
||||
assert ctx.user == mock_user
|
||||
|
||||
@patch("context.flask_app_context.current_app")
|
||||
@patch("context.flask_app_context.current_app", autospec=True)
|
||||
def test_capture_flask_context_with_explicit_user(self, mock_current_app):
|
||||
"""Test capture_flask_context uses explicit user parameter."""
|
||||
mock_app = MagicMock()
|
||||
@ -186,7 +186,7 @@ class TestCaptureFlaskContext:
|
||||
|
||||
assert ctx.user == explicit_user
|
||||
|
||||
@patch("context.flask_app_context.current_app")
|
||||
@patch("context.flask_app_context.current_app", autospec=True)
|
||||
def test_capture_flask_context_captures_contextvars(self, mock_current_app):
|
||||
"""Test capture_flask_context captures context variables."""
|
||||
mock_app = MagicMock()
|
||||
@ -267,7 +267,7 @@ class TestFlaskExecutionContextIntegration:
|
||||
# Verify app context was entered
|
||||
assert mock_flask_app.app_context.called
|
||||
|
||||
@patch("context.flask_app_context.g")
|
||||
@patch("context.flask_app_context.g", autospec=True)
|
||||
def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
|
||||
"""Test that enter restores user in Flask g object."""
|
||||
mock_user = MagicMock()
|
||||
|
||||
@ -138,10 +138,10 @@ class TestGraphRuntimeState:
|
||||
_ = state.response_coordinator
|
||||
|
||||
mock_graph = MagicMock()
|
||||
with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
|
||||
coordinator_instance = MagicMock()
|
||||
coordinator_cls.return_value = coordinator_instance
|
||||
|
||||
with patch(
|
||||
"core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True
|
||||
) as coordinator_cls:
|
||||
coordinator_instance = coordinator_cls.return_value
|
||||
state.configure(graph=mock_graph)
|
||||
|
||||
assert state.response_coordinator is coordinator_instance
|
||||
@ -204,7 +204,7 @@ class TestGraphRuntimeState:
|
||||
|
||||
mock_graph = MagicMock()
|
||||
stub = StubCoordinator()
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True):
|
||||
state.attach_graph(mock_graph)
|
||||
|
||||
stub.state = "configured"
|
||||
@ -230,7 +230,7 @@ class TestGraphRuntimeState:
|
||||
assert restored_execution.started is True
|
||||
|
||||
new_stub = StubCoordinator()
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True):
|
||||
restored.attach_graph(mock_graph)
|
||||
|
||||
assert new_stub.state == "configured"
|
||||
@ -251,14 +251,14 @@ class TestGraphRuntimeState:
|
||||
|
||||
mock_graph = MagicMock()
|
||||
original_stub = StubCoordinator()
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True):
|
||||
state.attach_graph(mock_graph)
|
||||
|
||||
original_stub.state = "configured"
|
||||
snapshot = state.dumps()
|
||||
|
||||
new_stub = StubCoordinator()
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
|
||||
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True):
|
||||
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
restored.attach_graph(mock_graph)
|
||||
restored.loads(snapshot)
|
||||
|
||||
@ -63,7 +63,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
|
||||
assert entity.resumed_at is None
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True)
|
||||
def test_get_state_first_call(self, mock_storage):
|
||||
"""Test get_state loads from storage on first call."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
@ -81,7 +81,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
assert entity._cached_state == state_data
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True)
|
||||
def test_get_state_cached_call(self, mock_storage):
|
||||
"""Test get_state returns cached data on subsequent calls."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
@ -102,7 +102,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
# Storage should only be called once
|
||||
mock_storage.load.assert_called_once_with("test-state-key")
|
||||
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
|
||||
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True)
|
||||
def test_get_state_with_pre_cached_data(self, mock_storage):
|
||||
"""Test get_state returns pre-cached data."""
|
||||
state_data = b'{"test": "data", "step": 5}'
|
||||
@ -125,7 +125,7 @@ class TestPrivateWorkflowPauseEntity:
|
||||
# Test with binary data that's not valid JSON
|
||||
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
|
||||
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
|
||||
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True) as mock_storage:
|
||||
mock_storage.load.return_value = binary_data
|
||||
|
||||
mock_pause_model = MagicMock(spec=WorkflowPauseModel)
|
||||
|
||||
@ -90,14 +90,14 @@ def mock_tool_node():
|
||||
@pytest.fixture
|
||||
def mock_is_instrument_flag_enabled_false():
|
||||
"""Mock is_instrument_flag_enabled to return False."""
|
||||
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False):
|
||||
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False, autospec=True):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_is_instrument_flag_enabled_true():
|
||||
"""Mock is_instrument_flag_enabled to return True."""
|
||||
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True):
|
||||
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True, autospec=True):
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@ -117,9 +117,7 @@ def test_parallel_streaming_workflow():
|
||||
# Create node factory and graph
|
||||
node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
|
||||
with patch.object(
|
||||
DifyNodeFactory,
|
||||
"_build_model_instance_for_llm_node",
|
||||
return_value=MagicMock(spec=ModelInstance),
|
||||
DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True
|
||||
):
|
||||
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
|
||||
|
||||
|
||||
@ -378,7 +378,7 @@ class TestStopEventIntegration:
|
||||
class TestStopEventTimeoutBehavior:
|
||||
"""Test stop_event behavior with join timeouts."""
|
||||
|
||||
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread")
|
||||
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
|
||||
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
|
||||
"""Test that Dispatcher uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
@ -405,7 +405,7 @@ class TestStopEventTimeoutBehavior:
|
||||
|
||||
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker")
|
||||
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
|
||||
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
|
||||
"""Test that WorkerPool uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
@ -198,11 +198,10 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE
|
||||
provider_model_bundle.configuration.__class__,
|
||||
"get_provider_model",
|
||||
return_value=provider_model,
|
||||
autospec=True,
|
||||
),
|
||||
mock.patch.object(
|
||||
model_type_instance.__class__,
|
||||
"get_model_schema",
|
||||
return_value=model_config.model_schema,
|
||||
model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True
|
||||
),
|
||||
):
|
||||
fetch_model_config(
|
||||
|
||||
@ -128,7 +128,8 @@ class TestTemplateTransformNode:
|
||||
assert TemplateTransformNode.version() == "1"
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_simple_template(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
@ -165,7 +166,8 @@ class TestTemplateTransformNode:
|
||||
assert result.inputs["age"] == 30
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with None variable values."""
|
||||
@ -192,7 +194,8 @@ class TestTemplateTransformNode:
|
||||
assert result.inputs["value"] is None
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_code_execution_error(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
@ -215,7 +218,8 @@ class TestTemplateTransformNode:
|
||||
assert "Template syntax error" in result.error
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_output_length_exceeds_limit(
|
||||
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
@ -239,7 +243,8 @@ class TestTemplateTransformNode:
|
||||
assert "Output length exceeds" in result.error
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_complex_jinja2_template(
|
||||
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
|
||||
@ -303,7 +308,8 @@ class TestTemplateTransformNode:
|
||||
assert mapping["node_123.var2"] == ["sys", "input2"]
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with no variables (static template)."""
|
||||
@ -330,7 +336,8 @@ class TestTemplateTransformNode:
|
||||
assert result.inputs == {}
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with numeric variable values."""
|
||||
@ -369,7 +376,8 @@ class TestTemplateTransformNode:
|
||||
assert result.outputs["output"] == "Total: $31.5"
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with dictionary variable values."""
|
||||
@ -400,7 +408,8 @@ class TestTemplateTransformNode:
|
||||
assert "john@example.com" in result.outputs["output"]
|
||||
|
||||
@patch(
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
|
||||
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
|
||||
autospec=True,
|
||||
)
|
||||
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
|
||||
"""Test _run with list variable values."""
|
||||
|
||||
@ -92,7 +92,9 @@ def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[lis
|
||||
return messages
|
||||
|
||||
tool_runtime = MagicMock()
|
||||
with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform):
|
||||
with patch.object(
|
||||
ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True
|
||||
):
|
||||
generator = tool_node._transform_message(
|
||||
messages=iter([message]),
|
||||
tool_info={"provider_type": "builtin", "provider_id": "provider"},
|
||||
|
||||
@ -26,11 +26,8 @@ class TestWorkflowEntryRedisChannel:
|
||||
redis_channel = RedisChannel(mock_redis_client, "test:channel:key")
|
||||
|
||||
# Patch GraphEngine to verify it receives the Redis channel
|
||||
with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine:
|
||||
mock_graph_engine = MagicMock()
|
||||
MockGraphEngine.return_value = mock_graph_engine
|
||||
|
||||
# Create WorkflowEntry with Redis channel
|
||||
with patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine:
|
||||
mock_graph_engine = MockGraphEngine.return_value # Create WorkflowEntry with Redis channel
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
@ -63,15 +60,11 @@ class TestWorkflowEntryRedisChannel:
|
||||
|
||||
# Patch GraphEngine and InMemoryChannel
|
||||
with (
|
||||
patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine,
|
||||
patch("core.workflow.workflow_entry.InMemoryChannel") as MockInMemoryChannel,
|
||||
patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine,
|
||||
patch("core.workflow.workflow_entry.InMemoryChannel", autospec=True) as MockInMemoryChannel,
|
||||
):
|
||||
mock_graph_engine = MagicMock()
|
||||
MockGraphEngine.return_value = mock_graph_engine
|
||||
mock_inmemory_channel = MagicMock()
|
||||
MockInMemoryChannel.return_value = mock_inmemory_channel
|
||||
|
||||
# Create WorkflowEntry without providing a channel
|
||||
mock_graph_engine = MockGraphEngine.return_value
|
||||
mock_inmemory_channel = MockInMemoryChannel.return_value # Create WorkflowEntry without providing a channel
|
||||
workflow_entry = WorkflowEntry(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
@ -114,7 +107,7 @@ class TestWorkflowEntryRedisChannel:
|
||||
mock_event2 = MagicMock()
|
||||
|
||||
# Patch GraphEngine
|
||||
with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine:
|
||||
with patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine:
|
||||
mock_graph_engine = MagicMock()
|
||||
mock_graph_engine.run.return_value = iter([mock_event1, mock_event2])
|
||||
MockGraphEngine.return_value = mock_graph_engine
|
||||
|
||||
@ -40,7 +40,7 @@ def mock_upload_file():
|
||||
mock.source_url = TEST_REMOTE_URL
|
||||
mock.size = 1024
|
||||
mock.key = "test_key"
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock, autospec=True) as m:
|
||||
yield m
|
||||
|
||||
|
||||
@ -54,7 +54,7 @@ def mock_tool_file():
|
||||
mock.mimetype = "application/pdf"
|
||||
mock.original_url = "http://example.com/tool.pdf"
|
||||
mock.size = 2048
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock):
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock, autospec=True):
|
||||
yield mock
|
||||
|
||||
|
||||
@ -70,7 +70,7 @@ def mock_http_head():
|
||||
},
|
||||
)
|
||||
|
||||
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
|
||||
with patch("factories.file_factory.ssrf_proxy.head", autospec=True) as mock_head:
|
||||
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
|
||||
yield mock_head
|
||||
|
||||
@ -188,7 +188,7 @@ def test_build_from_remote_url_without_strict_validation(mock_http_head):
|
||||
|
||||
def test_tool_file_not_found():
|
||||
"""Test ToolFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True):
|
||||
mapping = tool_file_mapping()
|
||||
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
@ -196,7 +196,7 @@ def test_tool_file_not_found():
|
||||
|
||||
def test_local_file_not_found():
|
||||
"""Test UploadFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True):
|
||||
mapping = local_file_mapping()
|
||||
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
@ -268,7 +268,7 @@ def test_tenant_mismatch():
|
||||
mock_file.key = "test_key"
|
||||
|
||||
# Mock the database query to return None (no file found for this tenant)
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True):
|
||||
mapping = local_file_mapping()
|
||||
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
|
||||
@ -403,7 +403,7 @@ class TestRedisSubscription:
|
||||
|
||||
# ==================== Listener Thread Tests ====================
|
||||
|
||||
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
|
||||
@patch("time.sleep", side_effect=lambda x: None, autospec=True) # Speed up test
|
||||
def test_listener_thread_normal_operation(
|
||||
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
|
||||
):
|
||||
@ -826,7 +826,7 @@ class TestRedisShardedSubscription:
|
||||
|
||||
# ==================== Listener Thread Tests ====================
|
||||
|
||||
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
|
||||
@patch("time.sleep", side_effect=lambda x: None, autospec=True) # Speed up test
|
||||
def test_listener_thread_normal_operation(
|
||||
self, mock_sleep, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
|
||||
):
|
||||
|
||||
@ -104,7 +104,7 @@ class TestParseTimeRange:
|
||||
def test_parse_time_range_dst_ambiguous_time(self):
|
||||
"""Test parsing during DST ambiguous time (fall back)."""
|
||||
# This test simulates DST fall back where 2:30 AM occurs twice
|
||||
with patch("pytz.timezone") as mock_timezone:
|
||||
with patch("pytz.timezone", autospec=True) as mock_timezone:
|
||||
# Mock timezone that raises AmbiguousTimeError
|
||||
mock_tz = mock_timezone.return_value
|
||||
|
||||
@ -135,7 +135,7 @@ class TestParseTimeRange:
|
||||
|
||||
def test_parse_time_range_dst_nonexistent_time(self):
|
||||
"""Test parsing during DST nonexistent time (spring forward)."""
|
||||
with patch("pytz.timezone") as mock_timezone:
|
||||
with patch("pytz.timezone", autospec=True) as mock_timezone:
|
||||
# Mock timezone that raises NonExistentTimeError
|
||||
mock_tz = mock_timezone.return_value
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ class TestLoginRequired:
|
||||
with setup_app.test_request_context():
|
||||
# Mock authenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
@ -70,7 +70,7 @@ class TestLoginRequired:
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
@ -86,8 +86,8 @@ class TestLoginRequired:
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user and LOGIN_DISABLED
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login.dify_config") as mock_config:
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
with patch("libs.login.dify_config", autospec=True) as mock_config:
|
||||
mock_config.LOGIN_DISABLED = True
|
||||
|
||||
result = protected_view()
|
||||
@ -106,7 +106,7 @@ class TestLoginRequired:
|
||||
with setup_app.test_request_context(method="OPTIONS"):
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
@ -125,7 +125,7 @@ class TestLoginRequired:
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
@ -144,7 +144,7 @@ class TestLoginRequired:
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
@ -197,14 +197,14 @@ class TestCurrentUser:
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
|
||||
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
|
||||
"""Test that current_user proxy handles None user."""
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=None):
|
||||
with patch("libs.login._get_user", return_value=None, autospec=True):
|
||||
# When _get_user returns None, accessing attributes should fail
|
||||
# or current_user should evaluate to falsy
|
||||
try:
|
||||
@ -224,7 +224,7 @@ class TestCurrentUser:
|
||||
def check_user_in_thread(user_id: str, index: int):
|
||||
with app.test_request_context():
|
||||
mock_user = MockUser(user_id)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
|
||||
results[index] = current_user.id
|
||||
|
||||
# Create multiple threads with different users
|
||||
|
||||
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post")
|
||||
@patch("httpx.post", autospec=True)
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
|
||||
user_response = MagicMock()
|
||||
user_response.json.return_value = user_data
|
||||
@ -121,7 +121,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert user_info.name == user_data["name"]
|
||||
assert user_info.email == expected_email
|
||||
|
||||
@patch("httpx.get")
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_get, oauth):
|
||||
mock_get.side_effect = httpx.RequestError("Network error")
|
||||
|
||||
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({}, None, True),
|
||||
],
|
||||
)
|
||||
@patch("httpx.post")
|
||||
@patch("httpx.post", autospec=True)
|
||||
def test_should_retrieve_access_token(
|
||||
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
|
||||
):
|
||||
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
|
||||
mock_response.json.return_value = user_data
|
||||
mock_get.return_value = mock_response
|
||||
@ -222,7 +222,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
httpx.TimeoutException,
|
||||
],
|
||||
)
|
||||
@patch("httpx.get")
|
||||
@patch("httpx.get", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status.side_effect = exception_type("Error")
|
||||
|
||||
@ -9,11 +9,9 @@ def _mail() -> dict:
|
||||
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
@patch("libs.smtp.smtplib.SMTP", autospec=True)
|
||||
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
mock_smtp = mock_smtp_cls.return_value
|
||||
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
|
||||
client.send(_mail())
|
||||
|
||||
@ -22,11 +20,9 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
@patch("libs.smtp.smtplib.SMTP", autospec=True)
|
||||
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp_cls.return_value = mock_smtp
|
||||
|
||||
mock_smtp = mock_smtp_cls.return_value
|
||||
client = SMTPClient(
|
||||
server="smtp.example.com",
|
||||
port=587,
|
||||
@ -46,7 +42,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP_SSL")
|
||||
@patch("libs.smtp.smtplib.SMTP_SSL", autospec=True)
|
||||
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
|
||||
# Cover SMTP_SSL branch and TimeoutError handling
|
||||
mock_smtp = MagicMock()
|
||||
@ -67,7 +63,7 @@ def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
@patch("libs.smtp.smtplib.SMTP", autospec=True)
|
||||
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
|
||||
mock_smtp = MagicMock()
|
||||
mock_smtp.sendmail.side_effect = RuntimeError("oops")
|
||||
@ -79,7 +75,7 @@ def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
|
||||
mock_smtp.quit.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.smtp.smtplib.SMTP")
|
||||
@patch("libs.smtp.smtplib.SMTP", autospec=True)
|
||||
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
|
||||
# Ensure we hit the specific SMTPException except branch
|
||||
import smtplib
|
||||
|
||||
@ -301,7 +301,7 @@ class TestAppModelConfig:
|
||||
)
|
||||
|
||||
# Mock database query to return None
|
||||
with patch("models.model.db.session.query") as mock_query:
|
||||
with patch("models.model.db.session.query", autospec=True) as mock_query:
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
@ -952,7 +952,7 @@ class TestSiteModel:
|
||||
def test_site_generate_code(self):
|
||||
"""Test Site.generate_code static method."""
|
||||
# Mock database query to return 0 (no existing codes)
|
||||
with patch("models.model.db.session.query") as mock_query:
|
||||
with patch("models.model.db.session.query", autospec=True) as mock_query:
|
||||
mock_query.return_value.where.return_value.count.return_value = 0
|
||||
|
||||
# Act
|
||||
@ -1167,7 +1167,7 @@ class TestConversationStatusCount:
|
||||
conversation.id = str(uuid4())
|
||||
|
||||
# Mock the database query to return no messages
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
@ -1192,7 +1192,7 @@ class TestConversationStatusCount:
|
||||
conversation.id = conversation_id
|
||||
|
||||
# Mock the database query to return no messages with workflow_run_id
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
mock_scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
@ -1277,7 +1277,7 @@ class TestConversationStatusCount:
|
||||
return mock_result
|
||||
|
||||
# Act & Assert
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
|
||||
result = conversation.status_count
|
||||
|
||||
# Verify only 2 database queries were made (not N+1)
|
||||
@ -1340,7 +1340,7 @@ class TestConversationStatusCount:
|
||||
return mock_result
|
||||
|
||||
# Act
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
|
||||
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
|
||||
result = conversation.status_count
|
||||
|
||||
# Assert - query should include app_id filter
|
||||
@ -1385,7 +1385,7 @@ class TestConversationStatusCount:
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
# Mock the messages query
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
@ -1441,7 +1441,7 @@ class TestConversationStatusCount:
|
||||
),
|
||||
]
|
||||
|
||||
with patch("models.model.db.session.scalars") as mock_scalars:
|
||||
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
|
||||
|
||||
def mock_scalars_side_effect(query):
|
||||
mock_result = MagicMock()
|
||||
|
||||
@ -15,7 +15,7 @@ class TestTencentCos(BaseStorageTest):
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self, setup_tencent_cos_mock):
|
||||
"""Executed before each test method."""
|
||||
with patch.object(CosConfig, "__init__", return_value=None):
|
||||
with patch.object(CosConfig, "__init__", return_value=None, autospec=True):
|
||||
self.storage = TencentCosStorage()
|
||||
self.storage.bucket_name = get_example_bucket()
|
||||
|
||||
@ -39,9 +39,9 @@ class TestTencentCosConfiguration:
|
||||
with (
|
||||
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
|
||||
patch(
|
||||
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
|
||||
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance, autospec=True
|
||||
) as mock_cos_config,
|
||||
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
|
||||
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client, autospec=True),
|
||||
):
|
||||
TencentCosStorage()
|
||||
|
||||
@ -72,9 +72,9 @@ class TestTencentCosConfiguration:
|
||||
with (
|
||||
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
|
||||
patch(
|
||||
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
|
||||
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance, autospec=True
|
||||
) as mock_cos_config,
|
||||
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
|
||||
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client, autospec=True),
|
||||
):
|
||||
TencentCosStorage()
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ class TestApiKeyAuthFactory:
|
||||
)
|
||||
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
|
||||
"""Test getting auth factory for all valid providers"""
|
||||
with patch(auth_class_path) as mock_auth:
|
||||
with patch(auth_class_path, autospec=True) as mock_auth:
|
||||
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
|
||||
assert auth_class == mock_auth
|
||||
|
||||
@ -46,7 +46,7 @@ class TestApiKeyAuthFactory:
|
||||
(False, False),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
|
||||
def test_validate_credentials_delegates_to_auth_instance(
|
||||
self, mock_get_factory, credentials_return_value, expected_result
|
||||
):
|
||||
@ -65,7 +65,7 @@ class TestApiKeyAuthFactory:
|
||||
assert result is expected_result
|
||||
mock_auth_instance.validate_credentials.assert_called_once()
|
||||
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
|
||||
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
|
||||
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
|
||||
"""Test that exceptions from auth instance are propagated"""
|
||||
# Arrange
|
||||
|
||||
@ -65,7 +65,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -96,7 +96,7 @@ class TestFirecrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@ -118,7 +118,7 @@ class TestFirecrawlAuth:
|
||||
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@ -145,7 +145,7 @@ class TestFirecrawlAuth:
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_post.side_effect = exception_type(exception_message)
|
||||
@ -167,7 +167,7 @@ class TestFirecrawlAuth:
|
||||
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_post):
|
||||
"""Test that custom base URL is used in validation and normalized"""
|
||||
mock_response = MagicMock()
|
||||
@ -185,7 +185,7 @@ class TestFirecrawlAuth:
|
||||
assert result is True
|
||||
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
|
||||
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post")
|
||||
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
@ -35,7 +35,7 @@ class TestJinaAuth:
|
||||
JinaAuth(credentials)
|
||||
assert str(exc_info.value) == "No API key provided"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_post):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -53,7 +53,7 @@ class TestJinaAuth:
|
||||
json={"url": "https://example.com"},
|
||||
)
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_http_402_error(self, mock_post):
|
||||
"""Test handling of 402 Payment Required error"""
|
||||
mock_response = MagicMock()
|
||||
@ -68,7 +68,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_http_409_error(self, mock_post):
|
||||
"""Test handling of 409 Conflict error"""
|
||||
mock_response = MagicMock()
|
||||
@ -83,7 +83,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_http_500_error(self, mock_post):
|
||||
"""Test handling of 500 Internal Server Error"""
|
||||
mock_response = MagicMock()
|
||||
@ -98,7 +98,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
|
||||
"""Test handling of unexpected errors with text response"""
|
||||
mock_response = MagicMock()
|
||||
@ -114,7 +114,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_unexpected_error_without_text(self, mock_post):
|
||||
"""Test handling of unexpected errors without text response"""
|
||||
mock_response = MagicMock()
|
||||
@ -130,7 +130,7 @@ class TestJinaAuth:
|
||||
auth.validate_credentials()
|
||||
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
|
||||
|
||||
@patch("services.auth.jina.jina.httpx.post")
|
||||
@patch("services.auth.jina.jina.httpx.post", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_post):
|
||||
"""Test handling of network connection errors"""
|
||||
mock_post.side_effect = httpx.ConnectError("Network error")
|
||||
|
||||
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth(credentials)
|
||||
assert str(exc_info.value) == expected_error
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
|
||||
"""Test successful credential validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
|
||||
(500, "Internal server error"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
|
||||
"""Test handling of various HTTP error codes"""
|
||||
mock_response = MagicMock()
|
||||
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
|
||||
(401, "Not JSON", True, "Expecting value"), # JSON decode error
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_unexpected_errors(
|
||||
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
|
||||
):
|
||||
@ -132,7 +132,7 @@ class TestWatercrawlAuth:
|
||||
(httpx.ConnectTimeout, "Connection timeout"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
|
||||
"""Test handling of various network-related errors including timeouts"""
|
||||
mock_get.side_effect = exception_type(exception_message)
|
||||
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
|
||||
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
|
||||
assert "super_secret_key_12345" not in str(exc_info.value)
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_use_custom_base_url_in_validation(self, mock_get):
|
||||
"""Test that custom base URL is used in validation"""
|
||||
mock_response = MagicMock()
|
||||
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
|
||||
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
|
||||
],
|
||||
)
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
|
||||
"""Test that urljoin is used correctly for URL construction with various base URLs"""
|
||||
mock_response = MagicMock()
|
||||
@ -193,7 +193,7 @@ class TestWatercrawlAuth:
|
||||
# Verify the correct URL was called
|
||||
assert mock_get.call_args[0][0] == expected_url
|
||||
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get")
|
||||
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
|
||||
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
|
||||
"""Test that timeout errors are handled gracefully with appropriate error message"""
|
||||
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
|
||||
|
||||
@ -27,7 +27,7 @@ class TestTraceparentPropagation:
|
||||
@pytest.fixture
|
||||
def mock_httpx_client(self):
|
||||
"""Mock httpx.Client for testing."""
|
||||
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
|
||||
with patch("services.enterprise.base.httpx.Client", autospec=True) as mock_client_class:
|
||||
mock_client = MagicMock()
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client
|
||||
mock_client_class.return_value.__exit__.return_value = None
|
||||
@ -44,7 +44,9 @@ class TestTraceparentPropagation:
|
||||
# Arrange
|
||||
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
|
||||
|
||||
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
|
||||
with patch(
|
||||
"services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent, autospec=True
|
||||
):
|
||||
# Act
|
||||
EnterpriseRequest.send_request("GET", "/test")
|
||||
|
||||
|
||||
@ -135,8 +135,8 @@ class TestExternalDatasetServiceGetExternalKnowledgeApis:
|
||||
"""
|
||||
|
||||
with (
|
||||
patch("services.external_knowledge_service.db.paginate") as mock_paginate,
|
||||
patch("services.external_knowledge_service.select"),
|
||||
patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
|
||||
patch("services.external_knowledge_service.select", autospec=True),
|
||||
):
|
||||
yield mock_paginate
|
||||
|
||||
@ -245,7 +245,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
Patch ``db.session`` for all CRUD tests in this class.
|
||||
"""
|
||||
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
|
||||
@ -263,7 +263,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
|
||||
}
|
||||
|
||||
# We do not want to actually call the remote endpoint here, so we patch the validator.
|
||||
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
|
||||
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
|
||||
result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
|
||||
|
||||
assert isinstance(result, ExternalKnowledgeApis)
|
||||
@ -386,7 +386,7 @@ class TestExternalDatasetServiceUsageAndBindings:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
|
||||
@ -447,7 +447,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
|
||||
@ -520,7 +520,7 @@ class TestExternalDatasetServiceProcessExternalApi:
|
||||
|
||||
fake_response = httpx.Response(200)
|
||||
|
||||
with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
|
||||
with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
|
||||
mock_post.return_value = fake_response
|
||||
|
||||
result = ExternalDatasetService.process_external_api(settings, files=None)
|
||||
@ -681,7 +681,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_create_external_dataset_success(self, mock_db_session: MagicMock):
|
||||
@ -801,7 +801,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
with patch("services.external_knowledge_service.db.session") as mock_session:
|
||||
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
|
||||
yield mock_session
|
||||
|
||||
def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
|
||||
@ -838,7 +838,9 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
|
||||
metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
|
||||
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
|
||||
with patch.object(
|
||||
ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
|
||||
) as mock_process:
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
@ -908,7 +910,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
|
||||
fake_response.status_code = 500
|
||||
fake_response.json.return_value = {}
|
||||
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
|
||||
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
|
||||
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
|
||||
tenant_id="tenant-1",
|
||||
dataset_id="ds-1",
|
||||
|
||||
@ -146,7 +146,7 @@ class TestHitTestingServiceRetrieve:
|
||||
Provides a mocked database session for testing database operations
|
||||
like adding and committing DatasetQuery records.
|
||||
"""
|
||||
with patch("services.hit_testing_service.db.session") as mock_db:
|
||||
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
|
||||
@ -174,9 +174,11 @@ class TestHitTestingServiceRetrieve:
|
||||
]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1] # start, end
|
||||
mock_retrieve.return_value = documents
|
||||
@ -218,9 +220,11 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_retrieve.return_value = documents
|
||||
@ -268,10 +272,12 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
|
||||
@ -311,8 +317,10 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
):
|
||||
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
|
||||
mock_format.return_value = []
|
||||
@ -346,9 +354,11 @@ class TestHitTestingServiceRetrieve:
|
||||
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
|
||||
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_retrieve.return_value = documents
|
||||
@ -380,7 +390,7 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
Provides a mocked database session for testing database operations
|
||||
like adding and committing DatasetQuery records.
|
||||
"""
|
||||
with patch("services.hit_testing_service.db.session") as mock_db:
|
||||
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_external_retrieve_success(self, mock_db_session):
|
||||
@ -403,8 +413,10 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
|
||||
) as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_external_retrieve.return_value = external_documents
|
||||
@ -467,8 +479,10 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
|
||||
) as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_external_retrieve.return_value = external_documents
|
||||
@ -499,8 +513,10 @@ class TestHitTestingServiceExternalRetrieve:
|
||||
metadata_filtering_conditions = {}
|
||||
|
||||
with (
|
||||
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
|
||||
patch(
|
||||
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
|
||||
) as mock_external_retrieve,
|
||||
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
|
||||
):
|
||||
mock_perf_counter.side_effect = [0.0, 0.1]
|
||||
mock_external_retrieve.return_value = []
|
||||
@ -542,7 +558,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
|
||||
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
|
||||
]
|
||||
|
||||
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
|
||||
with patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format:
|
||||
mock_format.return_value = mock_records
|
||||
|
||||
# Act
|
||||
@ -566,7 +584,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
|
||||
query = "test query"
|
||||
documents = []
|
||||
|
||||
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
|
||||
with patch(
|
||||
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
|
||||
) as mock_format:
|
||||
mock_format.return_value = []
|
||||
|
||||
# Act
|
||||
|
||||
@ -147,7 +147,7 @@ class TestSegmentServiceCreateSegment:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -172,10 +172,12 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -219,10 +221,12 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -257,11 +261,13 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager_class,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -292,10 +298,12 @@ class TestSegmentServiceCreateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_segments_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -317,7 +325,7 @@ class TestSegmentServiceUpdateSegment:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -338,10 +346,10 @@ class TestSegmentServiceUpdateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_redis_get.return_value = None # Not indexing
|
||||
mock_hash.return_value = "new-hash"
|
||||
@ -368,10 +376,10 @@ class TestSegmentServiceUpdateSegment:
|
||||
args = SegmentUpdateArgs(enabled=False)
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
|
||||
patch("services.dataset_service.disable_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
|
||||
patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
@ -394,7 +402,7 @@ class TestSegmentServiceUpdateSegment:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
args = SegmentUpdateArgs(content="Updated content")
|
||||
|
||||
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
|
||||
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
|
||||
mock_redis_get.return_value = "1" # Indexing in progress
|
||||
|
||||
# Act & Assert
|
||||
@ -409,7 +417,7 @@ class TestSegmentServiceUpdateSegment:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
args = SegmentUpdateArgs(content="Updated content")
|
||||
|
||||
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
|
||||
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
|
||||
mock_redis_get.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@ -427,10 +435,10 @@ class TestSegmentServiceUpdateSegment:
|
||||
mock_db_session.query.return_value.where.return_value.first.return_value = segment
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_hash.return_value = "new-hash"
|
||||
@ -456,7 +464,7 @@ class TestSegmentServiceDeleteSegment:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_delete_segment_success(self, mock_db_session):
|
||||
@ -471,10 +479,10 @@ class TestSegmentServiceDeleteSegment:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.select") as mock_select,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
|
||||
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_select.return_value.where.return_value = mock_select
|
||||
@ -495,8 +503,8 @@ class TestSegmentServiceDeleteSegment:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
|
||||
@ -515,7 +523,7 @@ class TestSegmentServiceDeleteSegment:
|
||||
document = SegmentTestDataFactory.create_document_mock()
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
|
||||
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
|
||||
mock_redis_get.return_value = "1" # Deletion in progress
|
||||
|
||||
# Act & Assert
|
||||
@ -529,7 +537,7 @@ class TestSegmentServiceDeleteSegments:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -562,8 +570,8 @@ class TestSegmentServiceDeleteSegments:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.select") as mock_select_func,
|
||||
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select_func,
|
||||
):
|
||||
mock_select_func.return_value = mock_select
|
||||
|
||||
@ -594,7 +602,7 @@ class TestSegmentServiceUpdateSegmentsStatus:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -623,9 +631,9 @@ class TestSegmentServiceUpdateSegmentsStatus:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.enable_segments_to_index_task") as mock_task,
|
||||
patch("services.dataset_service.select") as mock_select_func,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select_func,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_select_func.return_value = mock_select
|
||||
@ -657,10 +665,10 @@ class TestSegmentServiceUpdateSegmentsStatus:
|
||||
mock_db_session.scalars.return_value = mock_scalars
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.get") as mock_redis_get,
|
||||
patch("services.dataset_service.disable_segments_from_index_task") as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch("services.dataset_service.select") as mock_select_func,
|
||||
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
|
||||
patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
patch("services.dataset_service.select", autospec=True) as mock_select_func,
|
||||
):
|
||||
mock_redis_get.return_value = None
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
@ -693,7 +701,7 @@ class TestSegmentServiceGetSegments:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -771,7 +779,7 @@ class TestSegmentServiceGetSegmentById:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_get_segment_by_id_success(self, mock_db_session):
|
||||
@ -814,7 +822,7 @@ class TestSegmentServiceGetChildChunks:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -876,7 +884,7 @@ class TestSegmentServiceGetChildChunkById:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_get_child_chunk_by_id_success(self, mock_db_session):
|
||||
@ -919,7 +927,7 @@ class TestSegmentServiceCreateChildChunk:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -942,9 +950,11 @@ class TestSegmentServiceCreateChildChunk:
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -972,9 +982,11 @@ class TestSegmentServiceCreateChildChunk:
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.redis_client.lock") as mock_lock,
|
||||
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
|
||||
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
|
||||
):
|
||||
mock_lock.return_value.__enter__ = Mock()
|
||||
mock_lock.return_value.__exit__ = Mock(return_value=None)
|
||||
@ -994,7 +1006,7 @@ class TestSegmentServiceUpdateChildChunk:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture
|
||||
@ -1014,8 +1026,10 @@ class TestSegmentServiceUpdateChildChunk:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
|
||||
@ -1040,8 +1054,10 @@ class TestSegmentServiceUpdateChildChunk:
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_now,
|
||||
patch(
|
||||
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service,
|
||||
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
|
||||
):
|
||||
mock_vector_service.side_effect = Exception("Vector indexing failed")
|
||||
mock_now.return_value = "2024-01-01T00:00:00"
|
||||
@ -1059,7 +1075,7 @@ class TestSegmentServiceDeleteChildChunk:
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""Mock database session."""
|
||||
with patch("services.dataset_service.db.session") as mock_db:
|
||||
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
|
||||
yield mock_db
|
||||
|
||||
def test_delete_child_chunk_success(self, mock_db_session):
|
||||
@ -1068,7 +1084,9 @@ class TestSegmentServiceDeleteChildChunk:
|
||||
chunk = SegmentTestDataFactory.create_child_chunk_mock()
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
|
||||
with patch(
|
||||
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service:
|
||||
# Act
|
||||
SegmentService.delete_child_chunk(chunk, dataset)
|
||||
|
||||
@ -1083,7 +1101,9 @@ class TestSegmentServiceDeleteChildChunk:
|
||||
chunk = SegmentTestDataFactory.create_child_chunk_mock()
|
||||
dataset = SegmentTestDataFactory.create_dataset_mock()
|
||||
|
||||
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
|
||||
with patch(
|
||||
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
|
||||
) as mock_vector_service:
|
||||
mock_vector_service.side_effect = Exception("Vector deletion failed")
|
||||
|
||||
# Act & Assert
|
||||
|
||||
@ -15,8 +15,8 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
class TestWorkflowRunArchiver:
|
||||
"""Tests for the WorkflowRunArchiver class."""
|
||||
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config")
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True)
|
||||
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True)
|
||||
def test_archiver_initialization(self, mock_get_storage, mock_config):
|
||||
"""Test archiver can be initialized with various options."""
|
||||
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver
|
||||
|
||||
@ -214,7 +214,7 @@ def factory():
|
||||
class TestAudioServiceASR:
|
||||
"""Test speech-to-text (ASR) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
|
||||
"""Test successful ASR transcription in CHAT mode."""
|
||||
# Arrange
|
||||
@ -226,9 +226,7 @@ class TestAudioServiceASR:
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -242,7 +240,7 @@ class TestAudioServiceASR:
|
||||
call_args = mock_model_instance.invoke_speech2text.call_args
|
||||
assert call_args.kwargs["user"] == "user-123"
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
|
||||
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
|
||||
# Arrange
|
||||
@ -254,9 +252,7 @@ class TestAudioServiceASR:
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -351,7 +347,7 @@ class TestAudioServiceASR:
|
||||
with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
|
||||
AudioService.transcript_asr(app_model=app, file=file)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
"""Test that ASR raises error when no model instance is available."""
|
||||
# Arrange
|
||||
@ -363,8 +359,7 @@ class TestAudioServiceASR:
|
||||
file = factory.create_file_storage_mock()
|
||||
|
||||
# Mock ModelManager to return None
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_manager.get_default_model_instance.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
@ -375,7 +370,7 @@ class TestAudioServiceASR:
|
||||
class TestAudioServiceTTS:
|
||||
"""Test text-to-speech (TTS) operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
|
||||
"""Test successful TTS with text input."""
|
||||
# Arrange
|
||||
@ -388,9 +383,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -412,8 +405,8 @@ class TestAudioServiceTTS:
|
||||
voice="en-US-Neural",
|
||||
)
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
|
||||
"""Test successful TTS with message ID."""
|
||||
# Arrange
|
||||
@ -437,9 +430,7 @@ class TestAudioServiceTTS:
|
||||
mock_query.first.return_value = message
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio from message"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -454,7 +445,7 @@ class TestAudioServiceTTS:
|
||||
assert result == b"audio from message"
|
||||
mock_model_instance.invoke_tts.assert_called_once()
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
|
||||
"""Test TTS uses default voice when none specified."""
|
||||
# Arrange
|
||||
@ -467,9 +458,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -486,7 +475,7 @@ class TestAudioServiceTTS:
|
||||
call_args = mock_model_instance.invoke_tts.call_args
|
||||
assert call_args.kwargs["voice"] == "default-voice"
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
|
||||
"""Test TTS gets first available voice when none is configured."""
|
||||
# Arrange
|
||||
@ -499,9 +488,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
|
||||
mock_model_instance.invoke_tts.return_value = b"audio data"
|
||||
@ -518,8 +505,8 @@ class TestAudioServiceTTS:
|
||||
call_args = mock_model_instance.invoke_tts.call_args
|
||||
assert call_args.kwargs["voice"] == "auto-voice"
|
||||
|
||||
@patch("services.audio_service.WorkflowService")
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.WorkflowService", autospec=True)
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_workflow_mode_with_draft(
|
||||
self, mock_model_manager_class, mock_workflow_service_class, factory
|
||||
):
|
||||
@ -533,14 +520,11 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock WorkflowService
|
||||
mock_workflow_service = MagicMock()
|
||||
mock_workflow_service_class.return_value = mock_workflow_service
|
||||
mock_workflow_service = mock_workflow_service_class.return_value
|
||||
mock_workflow_service.get_draft_workflow.return_value = draft_workflow
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.invoke_tts.return_value = b"draft audio"
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -565,7 +549,7 @@ class TestAudioServiceTTS:
|
||||
with pytest.raises(ValueError, match="Text is required"):
|
||||
AudioService.transcript_tts(app_model=app, text=None)
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None for invalid message ID format."""
|
||||
# Arrange
|
||||
@ -580,7 +564,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message doesn't exist."""
|
||||
# Arrange
|
||||
@ -601,7 +585,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.db.session")
|
||||
@patch("services.audio_service.db.session", autospec=True)
|
||||
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
|
||||
"""Test that TTS returns None when message answer is empty."""
|
||||
# Arrange
|
||||
@ -627,7 +611,7 @@ class TestAudioServiceTTS:
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS raises error when no voices are available."""
|
||||
# Arrange
|
||||
@ -640,9 +624,7 @@ class TestAudioServiceTTS:
|
||||
)
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = [] # No voices available
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -655,7 +637,7 @@ class TestAudioServiceTTS:
|
||||
class TestAudioServiceTTSVoices:
|
||||
"""Test TTS voice listing operations."""
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
|
||||
"""Test successful retrieval of TTS voices."""
|
||||
# Arrange
|
||||
@ -668,9 +650,7 @@ class TestAudioServiceTTSVoices:
|
||||
]
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.return_value = expected_voices
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
@ -682,7 +662,7 @@ class TestAudioServiceTTSVoices:
|
||||
assert result == expected_voices
|
||||
mock_model_instance.get_tts_voices.assert_called_once_with(language)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS voices raises error when no model instance is available."""
|
||||
# Arrange
|
||||
@ -690,15 +670,14 @@ class TestAudioServiceTTSVoices:
|
||||
language = "en-US"
|
||||
|
||||
# Mock ModelManager to return None
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_manager.get_default_model_instance.return_value = None
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
|
||||
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
|
||||
|
||||
@patch("services.audio_service.ModelManager")
|
||||
@patch("services.audio_service.ModelManager", autospec=True)
|
||||
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
|
||||
"""Test that TTS voices propagates exceptions from model instance."""
|
||||
# Arrange
|
||||
@ -706,9 +685,7 @@ class TestAudioServiceTTSVoices:
|
||||
language = "en-US"
|
||||
|
||||
# Mock ModelManager
|
||||
mock_model_manager = MagicMock()
|
||||
mock_model_manager_class.return_value = mock_model_manager
|
||||
|
||||
mock_model_manager = mock_model_manager_class.return_value
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
|
||||
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
|
||||
|
||||
@ -237,9 +237,9 @@ class TestConversationServiceSummarization:
|
||||
titles based on the first message.
|
||||
"""
|
||||
|
||||
@patch("services.conversation_service.db.session")
|
||||
@patch("services.conversation_service.ConversationService.get_conversation")
|
||||
@patch("services.conversation_service.ConversationService.auto_generate_name")
|
||||
@patch("services.conversation_service.db.session", autospec=True)
|
||||
@patch("services.conversation_service.ConversationService.get_conversation", autospec=True)
|
||||
@patch("services.conversation_service.ConversationService.auto_generate_name", autospec=True)
|
||||
def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
|
||||
"""
|
||||
Test renaming conversation with auto-generation enabled.
|
||||
|
||||
@ -28,10 +28,14 @@ class TestArchivedWorkflowRunDeletion:
|
||||
with (
|
||||
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
|
||||
patch(
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
|
||||
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker",
|
||||
return_value=session_maker,
|
||||
autospec=True,
|
||||
),
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
|
||||
patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run,
|
||||
patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True),
|
||||
patch.object(
|
||||
deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True
|
||||
) as mock_delete_run,
|
||||
):
|
||||
result = deleter.delete_by_run_id("run-1")
|
||||
|
||||
@ -46,7 +50,7 @@ class TestArchivedWorkflowRunDeletion:
|
||||
run.id = "run-1"
|
||||
run.tenant_id = "tenant-1"
|
||||
|
||||
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
|
||||
with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo:
|
||||
result = deleter._delete_run(run)
|
||||
|
||||
assert result.success is True
|
||||
|
||||
@ -402,7 +402,7 @@ class TestBillingDisabledPolicyFilterMessageIds:
|
||||
class TestCreateMessageCleanPolicy:
|
||||
"""Unit tests for create_message_clean_policy factory function."""
|
||||
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config")
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
|
||||
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
|
||||
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
|
||||
# Arrange
|
||||
@ -414,8 +414,8 @@ class TestCreateMessageCleanPolicy:
|
||||
# Assert
|
||||
assert isinstance(policy, BillingDisabledPolicy)
|
||||
|
||||
@patch("services.retention.conversation.messages_clean_policy.BillingService")
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config")
|
||||
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True)
|
||||
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
|
||||
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
|
||||
"""Test that BillingSandboxPolicy is created with correct internal values."""
|
||||
# Arrange
|
||||
@ -554,7 +554,7 @@ class TestMessagesCleanServiceFromDays:
|
||||
MessagesCleanService.from_days(policy=policy, days=-1)
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
@ -586,7 +586,7 @@ class TestMessagesCleanServiceFromDays:
|
||||
dry_run = True
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
@ -613,7 +613,7 @@ class TestMessagesCleanServiceFromDays:
|
||||
policy = BillingDisabledPolicy()
|
||||
|
||||
# Act
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
|
||||
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
|
||||
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
|
||||
mock_datetime.datetime.now.return_value = fixed_now
|
||||
mock_datetime.timedelta = datetime.timedelta
|
||||
|
||||
@ -134,8 +134,8 @@ def factory():
|
||||
class TestRecommendedAppServiceGetApps:
|
||||
"""Test get_recommended_apps_and_categories operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of recommended apps when apps are returned."""
|
||||
# Arrange
|
||||
@ -161,8 +161,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
|
||||
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback to builtin when no recommended apps are returned."""
|
||||
# Arrange
|
||||
@ -199,8 +199,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
# Verify fallback was called with en-US (hardcoded)
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
|
||||
"""Test fallback when recommended_apps key is None."""
|
||||
# Arrange
|
||||
@ -232,8 +232,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
assert result == builtin_response
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
|
||||
"""Test retrieval with different language codes."""
|
||||
# Arrange
|
||||
@ -262,8 +262,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
assert result["recommended_apps"][0]["id"] == f"app-{language}"
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that correct factory is selected based on mode."""
|
||||
# Arrange
|
||||
@ -292,8 +292,8 @@ class TestRecommendedAppServiceGetApps:
|
||||
class TestRecommendedAppServiceGetDetail:
|
||||
"""Test get_recommend_app_detail operations."""
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
|
||||
"""Test successful retrieval of app detail."""
|
||||
# Arrange
|
||||
@ -324,8 +324,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result["name"] == "Productivity App"
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail retrieval with different factory modes."""
|
||||
# Arrange
|
||||
@ -352,8 +352,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result["name"] == f"App from {mode}"
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
|
||||
"""Test that None is returned when app is not found."""
|
||||
# Arrange
|
||||
@ -375,8 +375,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
assert result is None
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
|
||||
"""Test handling of empty dict response."""
|
||||
# Arrange
|
||||
@ -397,8 +397,8 @@ class TestRecommendedAppServiceGetDetail:
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
|
||||
@patch("services.recommended_app_service.dify_config")
|
||||
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
|
||||
@patch("services.recommended_app_service.dify_config", autospec=True)
|
||||
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
|
||||
"""Test app detail with complex model configuration."""
|
||||
# Arrange
|
||||
|
||||
@ -201,8 +201,8 @@ def factory():
|
||||
class TestSavedMessageServicePagination:
|
||||
"""Test saved message pagination operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an Account user."""
|
||||
# Arrange
|
||||
@ -247,8 +247,8 @@ class TestSavedMessageServicePagination:
|
||||
include_ids=["msg-0", "msg-1", "msg-2"],
|
||||
)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with an EndUser."""
|
||||
# Arrange
|
||||
@ -301,8 +301,8 @@ class TestSavedMessageServicePagination:
|
||||
with pytest.raises(ValueError, match="User is required"):
|
||||
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination with last_id parameter."""
|
||||
# Arrange
|
||||
@ -340,8 +340,8 @@ class TestSavedMessageServicePagination:
|
||||
call_args = mock_message_pagination.call_args
|
||||
assert call_args.kwargs["last_id"] == last_id
|
||||
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
|
||||
"""Test pagination when user has no saved messages."""
|
||||
# Arrange
|
||||
@ -377,8 +377,8 @@ class TestSavedMessageServicePagination:
|
||||
class TestSavedMessageServiceSave:
|
||||
"""Test save message operations."""
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an Account user."""
|
||||
# Arrange
|
||||
@ -407,8 +407,8 @@ class TestSavedMessageServiceSave:
|
||||
assert saved_message.created_by_role == "account"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test saving a message for an EndUser."""
|
||||
# Arrange
|
||||
@ -437,7 +437,7 @@ class TestSavedMessageServiceSave:
|
||||
assert saved_message.created_by_role == "end_user"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that saving without user is a no-op."""
|
||||
# Arrange
|
||||
@ -451,8 +451,8 @@ class TestSavedMessageServiceSave:
|
||||
mock_db_session.add.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that saving an already saved message is idempotent."""
|
||||
# Arrange
|
||||
@ -480,8 +480,8 @@ class TestSavedMessageServiceSave:
|
||||
mock_db_session.commit.assert_not_called()
|
||||
mock_get_message.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.MessageService.get_message")
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
|
||||
"""Test that save validates message exists through MessageService."""
|
||||
# Arrange
|
||||
@ -508,7 +508,7 @@ class TestSavedMessageServiceSave:
|
||||
class TestSavedMessageServiceDelete:
|
||||
"""Test delete saved message operations."""
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_saved_message_for_account(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an Account user."""
|
||||
# Arrange
|
||||
@ -535,7 +535,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
|
||||
"""Test deleting a saved message for an EndUser."""
|
||||
# Arrange
|
||||
@ -562,7 +562,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_called_once_with(saved_message)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting without user is a no-op."""
|
||||
# Arrange
|
||||
@ -576,7 +576,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
|
||||
"""Test that deleting a non-existent saved message is a no-op."""
|
||||
# Arrange
|
||||
@ -597,7 +597,7 @@ class TestSavedMessageServiceDelete:
|
||||
mock_db_session.delete.assert_not_called()
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
@patch("services.saved_message_service.db.session")
|
||||
@patch("services.saved_message_service.db.session", autospec=True)
|
||||
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
|
||||
"""Test that delete only removes the user's own saved message."""
|
||||
# Arrange
|
||||
|
||||
@ -315,7 +315,7 @@ class TestTagServiceRetrieval:
|
||||
- get_tags_by_target_id: Get all tags bound to a specific target
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tags_with_binding_counts(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags with their binding counts.
|
||||
@ -372,7 +372,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify database query was called
|
||||
mock_db_session.query.assert_called_once()
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags filtered by keyword (case-insensitive).
|
||||
@ -426,7 +426,7 @@ class TestTagServiceRetrieval:
|
||||
# 2. Additional WHERE clause for keyword filtering
|
||||
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving target IDs by tag IDs.
|
||||
@ -482,7 +482,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify both queries were executed
|
||||
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that empty tag_ids returns empty list.
|
||||
@ -510,7 +510,7 @@ class TestTagServiceRetrieval:
|
||||
assert results == [], "Should return empty list for empty input"
|
||||
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tag_by_tag_name(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags by name.
|
||||
@ -552,7 +552,7 @@ class TestTagServiceRetrieval:
|
||||
assert len(results) == 1, "Should find exactly one tag"
|
||||
assert results[0].name == tag_name, "Tag name should match"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that missing tag_type or tag_name returns empty list.
|
||||
@ -580,7 +580,7 @@ class TestTagServiceRetrieval:
|
||||
# Verify no database queries were executed
|
||||
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tags_by_target_id(self, mock_db_session, factory):
|
||||
"""
|
||||
Test retrieving tags associated with a specific target.
|
||||
@ -651,10 +651,10 @@ class TestTagServiceCRUD:
|
||||
- get_tag_binding_count: Get count of bindings for a tag
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.uuid.uuid4")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
@patch("services.tag_service.uuid.uuid4", autospec=True)
|
||||
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test creating a new tag.
|
||||
@ -709,8 +709,8 @@ class TestTagServiceCRUD:
|
||||
assert added_tag.created_by == "user-123", "Created by should match current user"
|
||||
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test that creating a tag with duplicate name raises ValueError.
|
||||
@ -740,9 +740,9 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(ValueError, match="Tag name already exists"):
|
||||
TagService.save_tags(args)
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
|
||||
"""
|
||||
Test updating a tag name.
|
||||
@ -792,9 +792,9 @@ class TestTagServiceCRUD:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_update_tags_raises_error_for_duplicate_name(
|
||||
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
|
||||
):
|
||||
@ -826,7 +826,7 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(ValueError, match="Tag name already exists"):
|
||||
TagService.update_tags(args, tag_id="tag-123")
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that updating a non-existent tag raises NotFound.
|
||||
@ -848,8 +848,8 @@ class TestTagServiceCRUD:
|
||||
mock_query.first.return_value = None
|
||||
|
||||
# Mock duplicate check and current_user
|
||||
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[]):
|
||||
with patch("services.tag_service.current_user") as mock_user:
|
||||
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True):
|
||||
with patch("services.tag_service.current_user", autospec=True) as mock_user:
|
||||
mock_user.current_tenant_id = "tenant-123"
|
||||
args = {"name": "New Name", "type": "app"}
|
||||
|
||||
@ -858,7 +858,7 @@ class TestTagServiceCRUD:
|
||||
with pytest.raises(NotFound, match="Tag not found"):
|
||||
TagService.update_tags(args, tag_id="nonexistent")
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_get_tag_binding_count(self, mock_db_session, factory):
|
||||
"""
|
||||
Test getting the count of bindings for a tag.
|
||||
@ -894,7 +894,7 @@ class TestTagServiceCRUD:
|
||||
# Verify count matches expectation
|
||||
assert result == expected_count, "Binding count should match"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag(self, mock_db_session, factory):
|
||||
"""
|
||||
Test deleting a tag and its bindings.
|
||||
@ -950,7 +950,7 @@ class TestTagServiceCRUD:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag_raises_not_found(self, mock_db_session, factory):
|
||||
"""
|
||||
Test that deleting a non-existent tag raises NotFound.
|
||||
@ -996,9 +996,9 @@ class TestTagServiceBindings:
|
||||
- check_target_exists: Validate target (dataset/app) existence
|
||||
"""
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
|
||||
"""
|
||||
Test creating tag bindings.
|
||||
@ -1047,9 +1047,9 @@ class TestTagServiceBindings:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
|
||||
"""
|
||||
Test that saving duplicate bindings is idempotent.
|
||||
@ -1088,8 +1088,8 @@ class TestTagServiceBindings:
|
||||
# Verify no new binding was added (idempotent)
|
||||
mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
|
||||
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
|
||||
"""
|
||||
Test deleting a tag binding.
|
||||
@ -1136,8 +1136,8 @@ class TestTagServiceBindings:
|
||||
# Verify transaction was committed
|
||||
mock_db_session.commit.assert_called_once(), "Should commit transaction"
|
||||
|
||||
@patch("services.tag_service.TagService.check_target_exists")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
|
||||
"""
|
||||
Test that deleting a non-existent binding is a no-op.
|
||||
@ -1173,8 +1173,8 @@ class TestTagServiceBindings:
|
||||
# Verify no commit was made (nothing changed)
|
||||
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test validating that a dataset target exists.
|
||||
@ -1214,8 +1214,8 @@ class TestTagServiceBindings:
|
||||
# Verify no exception was raised and query was executed
|
||||
mock_db_session.query.assert_called_once(), "Should query database for dataset"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test validating that an app target exists.
|
||||
@ -1255,8 +1255,8 @@ class TestTagServiceBindings:
|
||||
# Verify no exception was raised and query was executed
|
||||
mock_db_session.query.assert_called_once(), "Should query database for app"
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_raises_not_found_for_missing_dataset(
|
||||
self, mock_db_session, mock_current_user, factory
|
||||
):
|
||||
@ -1287,8 +1287,8 @@ class TestTagServiceBindings:
|
||||
with pytest.raises(NotFound, match="Dataset not found"):
|
||||
TagService.check_target_exists("knowledge", "nonexistent")
|
||||
|
||||
@patch("services.tag_service.current_user")
|
||||
@patch("services.tag_service.db.session")
|
||||
@patch("services.tag_service.current_user", autospec=True)
|
||||
@patch("services.tag_service.db.session", autospec=True)
|
||||
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
|
||||
"""
|
||||
Test that missing app raises NotFound.
|
||||
|
||||
@ -87,7 +87,7 @@ class TestWebhookServiceUnit:
|
||||
webhook_trigger = MagicMock()
|
||||
webhook_trigger.tenant_id = "test_tenant"
|
||||
|
||||
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
|
||||
with patch.object(WebhookService, "_process_file_uploads", autospec=True) as mock_process_files:
|
||||
mock_process_files.return_value = {"file": "mocked_file_obj"}
|
||||
|
||||
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
|
||||
@ -123,8 +123,10 @@ class TestWebhookServiceUnit:
|
||||
mock_file.to_dict.return_value = {"file": "data"}
|
||||
|
||||
with (
|
||||
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
|
||||
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
|
||||
patch.object(
|
||||
WebhookService, "_detect_binary_mimetype", return_value="text/plain", autospec=True
|
||||
) as mock_detect,
|
||||
patch.object(WebhookService, "_create_file_from_binary", autospec=True) as mock_create,
|
||||
):
|
||||
mock_create.return_value = mock_file
|
||||
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
|
||||
@ -168,7 +170,7 @@ class TestWebhookServiceUnit:
|
||||
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
|
||||
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
|
||||
|
||||
with patch("services.trigger.webhook_service.logger") as mock_logger:
|
||||
with patch("services.trigger.webhook_service.logger", autospec=True) as mock_logger:
|
||||
result = WebhookService._detect_binary_mimetype(b"binary data")
|
||||
|
||||
assert result == "application/octet-stream"
|
||||
@ -245,15 +247,12 @@ class TestWebhookServiceUnit:
|
||||
assert response_data[0]["id"] == 1
|
||||
assert response_data[1]["id"] == 2
|
||||
|
||||
@patch("services.trigger.webhook_service.ToolFileManager")
|
||||
@patch("services.trigger.webhook_service.file_factory")
|
||||
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
|
||||
@patch("services.trigger.webhook_service.file_factory", autospec=True)
|
||||
def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
|
||||
"""Test successful file upload processing."""
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
@ -285,15 +284,12 @@ class TestWebhookServiceUnit:
|
||||
assert mock_tool_file_manager.call_count == 2
|
||||
assert mock_file_factory.build_from_mapping.call_count == 2
|
||||
|
||||
@patch("services.trigger.webhook_service.ToolFileManager")
|
||||
@patch("services.trigger.webhook_service.file_factory")
|
||||
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
|
||||
@patch("services.trigger.webhook_service.file_factory", autospec=True)
|
||||
def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
|
||||
"""Test file upload processing with errors."""
|
||||
# Mock ToolFileManager
|
||||
mock_tool_file_instance = MagicMock()
|
||||
mock_tool_file_manager.return_value = mock_tool_file_instance
|
||||
|
||||
# Mock file creation
|
||||
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
|
||||
mock_tool_file = MagicMock()
|
||||
mock_tool_file.id = "test_file_id"
|
||||
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
|
||||
@ -544,8 +540,8 @@ class TestWebhookServiceUnit:
|
||||
|
||||
# Mock the WebhookService methods
|
||||
with (
|
||||
patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
|
||||
patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
|
||||
patch.object(WebhookService, "get_webhook_trigger_and_workflow", autospec=True) as mock_get_trigger,
|
||||
patch.object(WebhookService, "extract_and_validate_webhook_data", autospec=True) as mock_extract,
|
||||
):
|
||||
mock_trigger = MagicMock()
|
||||
mock_workflow = MagicMock()
|
||||
|
||||
@ -124,7 +124,7 @@ class TestWorkflowRunService:
|
||||
"""Create WorkflowRunService instance with mocked dependencies."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
return service
|
||||
@ -135,7 +135,7 @@ class TestWorkflowRunService:
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(mock_engine)
|
||||
return service
|
||||
@ -146,7 +146,7 @@ class TestWorkflowRunService:
|
||||
"""Test WorkflowRunService initialization with session_factory."""
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
service = WorkflowRunService(session_factory)
|
||||
|
||||
@ -158,9 +158,11 @@ class TestWorkflowRunService:
|
||||
mock_engine = create_autospec(Engine)
|
||||
session_factory, _ = mock_session_factory
|
||||
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
|
||||
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
|
||||
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
|
||||
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
|
||||
with patch(
|
||||
"services.workflow_run_service.sessionmaker", return_value=session_factory, autospec=True
|
||||
) as mock_sessionmaker:
|
||||
service = WorkflowRunService(mock_engine)
|
||||
|
||||
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)
|
||||
|
||||
@ -141,7 +141,7 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
|
||||
) as _mock_try_offload:
|
||||
_mock_try_offload.return_value = None
|
||||
mock_segment = StringSegment(value="small value")
|
||||
@ -153,7 +153,7 @@ class TestDraftVariableSaver:
|
||||
|
||||
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
|
||||
with patch(
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
|
||||
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
|
||||
) as _mock_try_offload:
|
||||
mock_segment = StringSegment(value="small value")
|
||||
mock_draft_var_file = WorkflowDraftVariableFile(
|
||||
@ -170,7 +170,7 @@ class TestDraftVariableSaver:
|
||||
# Should not have large variable metadata
|
||||
assert draft_var.file_id == mock_draft_var_file.id
|
||||
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
|
||||
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True)
|
||||
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
|
||||
"""Test complete save workflow."""
|
||||
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
|
||||
@ -222,7 +222,7 @@ class TestWorkflowDraftVariableService:
|
||||
name="test_var",
|
||||
value=StringSegment(value="reset_value"),
|
||||
)
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
|
||||
with patch.object(service, "_reset_conv_var", return_value=expected_result, autospec=True) as mock_reset_conv:
|
||||
result = service.reset_variable(workflow, variable)
|
||||
|
||||
mock_reset_conv.assert_called_once_with(workflow, variable)
|
||||
@ -330,8 +330,8 @@ class TestWorkflowDraftVariableService:
|
||||
# Mock workflow methods
|
||||
mock_node_config = {"type": "test_node"}
|
||||
with (
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
|
||||
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True),
|
||||
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True),
|
||||
):
|
||||
result = service._reset_node_var_or_sys_var(workflow, variable)
|
||||
|
||||
|
||||
@ -50,7 +50,7 @@ def pipeline_id():
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
|
||||
with patch("tasks.clean_dataset_task.session_factory", autospec=True) as mock_sf:
|
||||
mock_session = MagicMock()
|
||||
# context manager for create_session()
|
||||
cm = MagicMock()
|
||||
@ -79,7 +79,7 @@ def mock_db_session():
|
||||
@pytest.fixture
|
||||
def mock_storage():
|
||||
"""Mock storage client."""
|
||||
with patch("tasks.clean_dataset_task.storage") as mock_storage:
|
||||
with patch("tasks.clean_dataset_task.storage", autospec=True) as mock_storage:
|
||||
mock_storage.delete.return_value = None
|
||||
yield mock_storage
|
||||
|
||||
@ -87,7 +87,7 @@ def mock_storage():
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock IndexProcessorFactory."""
|
||||
with patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_factory:
|
||||
with patch("tasks.clean_dataset_task.IndexProcessorFactory", autospec=True) as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean.return_value = None
|
||||
mock_factory_instance = MagicMock()
|
||||
@ -104,7 +104,7 @@ def mock_index_processor_factory():
|
||||
@pytest.fixture
|
||||
def mock_get_image_upload_file_ids():
|
||||
"""Mock get_image_upload_file_ids function."""
|
||||
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_func:
|
||||
with patch("tasks.clean_dataset_task.get_image_upload_file_ids", autospec=True) as mock_func:
|
||||
mock_func.return_value = []
|
||||
yield mock_func
|
||||
|
||||
|
||||
@ -75,7 +75,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id,
|
||||
@pytest.fixture
|
||||
def mock_db_session(mock_document, mock_dataset):
|
||||
"""Mock session_factory.create_session to drive deterministic read-only task flow."""
|
||||
with patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory:
|
||||
with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
|
||||
@ -96,7 +96,7 @@ def mock_db_session(mock_document, mock_dataset):
|
||||
@pytest.fixture
|
||||
def mock_datasource_provider_service():
|
||||
"""Mock datasource credential provider."""
|
||||
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
|
||||
with patch("tasks.document_indexing_sync_task.DatasourceProviderService", autospec=True) as mock_service_class:
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
|
||||
mock_service_class.return_value = mock_service
|
||||
@ -106,7 +106,7 @@ def mock_datasource_provider_service():
|
||||
@pytest.fixture
|
||||
def mock_notion_extractor():
|
||||
"""Mock notion extractor class and instance."""
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
|
||||
with patch("tasks.document_indexing_sync_task.NotionExtractor", autospec=True) as mock_extractor_class:
|
||||
mock_extractor = MagicMock()
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
|
||||
@ -95,7 +95,7 @@ def mock_document_segments(document_ids):
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Mock database session via session_factory.create_session()."""
|
||||
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
|
||||
with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf:
|
||||
session = MagicMock()
|
||||
# Allow tests to observe session.close() via context manager teardown
|
||||
session.close = MagicMock()
|
||||
@ -118,7 +118,7 @@ def mock_db_session():
|
||||
@pytest.fixture
|
||||
def mock_indexing_runner():
|
||||
"""Mock IndexingRunner."""
|
||||
with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class:
|
||||
with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class:
|
||||
mock_runner = MagicMock(spec=IndexingRunner)
|
||||
mock_runner_class.return_value = mock_runner
|
||||
yield mock_runner
|
||||
@ -127,7 +127,7 @@ def mock_indexing_runner():
|
||||
@pytest.fixture
|
||||
def mock_feature_service():
|
||||
"""Mock FeatureService."""
|
||||
with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service:
|
||||
with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service:
|
||||
mock_features = Mock()
|
||||
mock_features.billing = Mock()
|
||||
mock_features.billing.enabled = False
|
||||
@ -141,7 +141,7 @@ def mock_feature_service():
|
||||
@pytest.fixture
|
||||
def mock_index_processor_factory():
|
||||
"""Mock IndexProcessorFactory."""
|
||||
with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory:
|
||||
with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory:
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.clean = Mock()
|
||||
mock_factory.return_value.init_index_processor.return_value = mock_processor
|
||||
@ -151,7 +151,7 @@ def mock_index_processor_factory():
|
||||
@pytest.fixture
|
||||
def mock_tenant_isolated_queue():
|
||||
"""Mock TenantIsolatedTaskQueue."""
|
||||
with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class:
|
||||
with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class:
|
||||
mock_queue = MagicMock(spec=TenantIsolatedTaskQueue)
|
||||
mock_queue.pull_tasks.return_value = []
|
||||
mock_queue.delete_task_key = Mock()
|
||||
@ -168,7 +168,7 @@ def mock_tenant_isolated_queue():
|
||||
class TestDuplicateDocumentIndexingTask:
|
||||
"""Tests for the deprecated duplicate_document_indexing_task function."""
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
|
||||
def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids):
|
||||
"""Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function."""
|
||||
# Act
|
||||
@ -177,7 +177,7 @@ class TestDuplicateDocumentIndexingTask:
|
||||
# Assert
|
||||
mock_core_func.assert_called_once_with(dataset_id, document_ids)
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
|
||||
def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id):
|
||||
"""Test duplicate_document_indexing_task with empty document_ids list."""
|
||||
# Arrange
|
||||
@ -445,7 +445,7 @@ class TestDuplicateDocumentIndexingTaskCore:
|
||||
class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
"""Tests for _duplicate_document_indexing_task_with_tenant_queue function."""
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
|
||||
def test_tenant_queue_wrapper_calls_core_function(
|
||||
self,
|
||||
mock_core_func,
|
||||
@ -464,7 +464,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
# Assert
|
||||
mock_core_func.assert_called_once_with(dataset_id, document_ids)
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
|
||||
def test_tenant_queue_wrapper_deletes_key_when_no_tasks(
|
||||
self,
|
||||
mock_core_func,
|
||||
@ -484,7 +484,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
# Assert
|
||||
mock_tenant_isolated_queue.delete_task_key.assert_called_once()
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
|
||||
def test_tenant_queue_wrapper_processes_next_tasks(
|
||||
self,
|
||||
mock_core_func,
|
||||
@ -514,7 +514,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
|
||||
def test_tenant_queue_wrapper_handles_core_function_error(
|
||||
self,
|
||||
mock_core_func,
|
||||
@ -544,7 +544,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
|
||||
class TestNormalDuplicateDocumentIndexingTask:
|
||||
"""Tests for normal_duplicate_document_indexing_task function."""
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
|
||||
def test_normal_task_calls_tenant_queue_wrapper(
|
||||
self,
|
||||
mock_wrapper_func,
|
||||
@ -561,7 +561,7 @@ class TestNormalDuplicateDocumentIndexingTask:
|
||||
tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
|
||||
)
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
|
||||
def test_normal_task_with_empty_document_ids(
|
||||
self,
|
||||
mock_wrapper_func,
|
||||
@ -589,7 +589,7 @@ class TestNormalDuplicateDocumentIndexingTask:
|
||||
class TestPriorityDuplicateDocumentIndexingTask:
|
||||
"""Tests for priority_duplicate_document_indexing_task function."""
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
|
||||
def test_priority_task_calls_tenant_queue_wrapper(
|
||||
self,
|
||||
mock_wrapper_func,
|
||||
@ -606,7 +606,7 @@ class TestPriorityDuplicateDocumentIndexingTask:
|
||||
tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
|
||||
)
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
|
||||
def test_priority_task_with_single_document(
|
||||
self,
|
||||
mock_wrapper_func,
|
||||
@ -625,7 +625,7 @@ class TestPriorityDuplicateDocumentIndexingTask:
|
||||
tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
|
||||
)
|
||||
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
|
||||
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
|
||||
def test_priority_task_with_large_batch(
|
||||
self,
|
||||
mock_wrapper_func,
|
||||
|
||||
@ -321,7 +321,9 @@ def test_structured_output_parser():
|
||||
)
|
||||
else:
|
||||
# Test successful cases
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
with patch(
|
||||
"core.llm_generator.output_parser.structured_output.json_repair.loads", autospec=True
|
||||
) as mock_json_repair:
|
||||
# Configure json_repair mock for cases that need it
|
||||
if case["name"] == "json_repair_scenario":
|
||||
mock_json_repair.return_value = {"name": "test"}
|
||||
@ -402,7 +404,9 @@ def test_parse_structured_output_edge_cases():
|
||||
|
||||
prompt_messages = [UserPromptMessage(content="Test reasoning")]
|
||||
|
||||
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
|
||||
with patch(
|
||||
"core.llm_generator.output_parser.structured_output.json_repair.loads", autospec=True
|
||||
) as mock_json_repair:
|
||||
# Mock json_repair to return a list with dict
|
||||
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user