test(api): add autospec to MagicMock-based patch usage (#32752)

This commit is contained in:
-LAN- 2026-03-01 04:30:45 +08:00 committed by GitHub
parent c034eb036c
commit 20fcc95db9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 865 additions and 804 deletions

View File

@ -360,7 +360,7 @@ class TestEndToEndCacheFlow:
class TestRedisFailover:
"""Test behavior when Redis is unavailable."""
@patch("services.api_token_service.redis_client")
@patch("services.api_token_service.redis_client", autospec=True)
def test_graceful_degradation_when_redis_fails(self, mock_redis):
"""Test system degrades gracefully when Redis is unavailable."""
from redis import RedisError

View File

@ -41,17 +41,15 @@ class TestOpenSearchConfig:
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] == ("admin", "password")
@patch("boto3.Session")
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
@patch("boto3.Session", autospec=True)
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth", autospec=True)
def test_to_opensearch_params_with_aws_managed_iam(
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
):
mock_credentials = MagicMock()
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
mock_auth_instance = MagicMock()
mock_aws_signer_auth.return_value = mock_auth_instance
mock_auth_instance = mock_aws_signer_auth.return_value
aws_region = "ap-southeast-2"
aws_service = "aoss"
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
@ -157,7 +155,7 @@ class TestOpenSearchVector:
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
embedding = [0.1] * 128
with patch("opensearchpy.helpers.bulk") as mock_bulk:
with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk:
mock_bulk.return_value = ([], [])
self.vector.add_texts([doc], [embedding])
@ -171,7 +169,7 @@ class TestOpenSearchVector:
doc = Document(page_content="Test content", metadata={"document_id": self.example_doc_id})
embedding = [0.1] * 128
with patch("opensearchpy.helpers.bulk") as mock_bulk:
with patch("opensearchpy.helpers.bulk", autospec=True) as mock_bulk:
mock_bulk.return_value = ([], [])
self.vector.add_texts([doc], [embedding])

View File

@ -19,14 +19,14 @@ class TestAgentService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.agent_service.PluginAgentClient") as mock_plugin_agent_client,
patch("services.agent_service.ToolManager") as mock_tool_manager,
patch("services.agent_service.AgentConfigManager") as mock_agent_config_manager,
patch("services.agent_service.PluginAgentClient", autospec=True) as mock_plugin_agent_client,
patch("services.agent_service.ToolManager", autospec=True) as mock_tool_manager,
patch("services.agent_service.AgentConfigManager", autospec=True) as mock_agent_config_manager,
patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user,
patch("services.app_service.FeatureService") as mock_feature_service,
patch("services.app_service.EnterpriseService") as mock_enterprise_service,
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch("services.app_service.FeatureService", autospec=True) as mock_feature_service,
patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service,
patch("services.app_service.ModelManager", autospec=True) as mock_model_manager,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
):
# Setup default mock returns for agent service
mock_plugin_agent_client_instance = mock_plugin_agent_client.return_value

View File

@ -18,18 +18,22 @@ class TestAppGenerateService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.billing_service.BillingService") as mock_billing_service,
patch("services.app_generate_service.WorkflowService") as mock_workflow_service,
patch("services.app_generate_service.RateLimit") as mock_rate_limit,
patch("services.app_generate_service.CompletionAppGenerator") as mock_completion_generator,
patch("services.app_generate_service.ChatAppGenerator") as mock_chat_generator,
patch("services.app_generate_service.AgentChatAppGenerator") as mock_agent_chat_generator,
patch("services.app_generate_service.AdvancedChatAppGenerator") as mock_advanced_chat_generator,
patch("services.app_generate_service.WorkflowAppGenerator") as mock_workflow_generator,
patch("services.app_generate_service.MessageBasedAppGenerator") as mock_message_based_generator,
patch("services.account_service.FeatureService") as mock_account_feature_service,
patch("services.app_generate_service.dify_config") as mock_dify_config,
patch("configs.dify_config") as mock_global_dify_config,
patch("services.billing_service.BillingService", autospec=True) as mock_billing_service,
patch("services.app_generate_service.WorkflowService", autospec=True) as mock_workflow_service,
patch("services.app_generate_service.RateLimit", autospec=True) as mock_rate_limit,
patch("services.app_generate_service.CompletionAppGenerator", autospec=True) as mock_completion_generator,
patch("services.app_generate_service.ChatAppGenerator", autospec=True) as mock_chat_generator,
patch("services.app_generate_service.AgentChatAppGenerator", autospec=True) as mock_agent_chat_generator,
patch(
"services.app_generate_service.AdvancedChatAppGenerator", autospec=True
) as mock_advanced_chat_generator,
patch("services.app_generate_service.WorkflowAppGenerator", autospec=True) as mock_workflow_generator,
patch(
"services.app_generate_service.MessageBasedAppGenerator", autospec=True
) as mock_message_based_generator,
patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service,
patch("services.app_generate_service.dify_config", autospec=True) as mock_dify_config,
patch("configs.dify_config", autospec=True) as mock_global_dify_config,
):
# Setup default mock returns for billing service
mock_billing_service.update_tenant_feature_plan_usage.return_value = {
@ -983,7 +987,7 @@ class TestAppGenerateService:
}
# Execute the method under test
with patch("services.app_generate_service.AppExecutionParams") as mock_exec_params:
with patch("services.app_generate_service.AppExecutionParams", autospec=True) as mock_exec_params:
mock_payload = MagicMock()
mock_payload.workflow_run_id = fake.uuid4()
mock_payload.model_dump_json.return_value = "{}"

View File

@ -17,10 +17,12 @@ class TestModelLoadBalancingService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.model_load_balancing_service.ProviderManager") as mock_provider_manager,
patch("services.model_load_balancing_service.LBModelManager") as mock_lb_model_manager,
patch("services.model_load_balancing_service.ModelProviderFactory") as mock_model_provider_factory,
patch("services.model_load_balancing_service.encrypter") as mock_encrypter,
patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager,
patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager,
patch(
"services.model_load_balancing_service.ModelProviderFactory", autospec=True
) as mock_model_provider_factory,
patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter,
):
# Setup default mock returns
mock_provider_manager_instance = mock_provider_manager.return_value

View File

@ -17,8 +17,8 @@ class TestModelProviderService:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("services.model_provider_service.ProviderManager") as mock_provider_manager,
patch("services.model_provider_service.ModelProviderFactory") as mock_model_provider_factory,
patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager,
patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory,
):
# Setup default mock returns
mock_provider_manager.return_value.get_configurations.return_value = MagicMock()
@ -526,7 +526,9 @@ class TestModelProviderService:
# Act: Execute the method under test
service = ModelProviderService()
with patch.object(service, "get_provider_credential", return_value=expected_credentials) as mock_method:
with patch.object(
service, "get_provider_credential", return_value=expected_credentials, autospec=True
) as mock_method:
result = service.get_provider_credential(tenant.id, "openai")
# Assert: Verify the expected outcomes
@ -854,7 +856,9 @@ class TestModelProviderService:
# Act: Execute the method under test
service = ModelProviderService()
with patch.object(service, "get_model_credential", return_value=expected_credentials) as mock_method:
with patch.object(
service, "get_model_credential", return_value=expected_credentials, autospec=True
) as mock_method:
result = service.get_model_credential(tenant.id, "openai", "llm", "gpt-4", None)
# Assert: Verify the expected outcomes

View File

@ -22,16 +22,13 @@ class TestWebhookService:
def mock_external_dependencies(self):
"""Mock external service dependencies."""
with (
patch("services.trigger.webhook_service.AsyncWorkflowService") as mock_async_service,
patch("services.trigger.webhook_service.ToolFileManager") as mock_tool_file_manager,
patch("services.trigger.webhook_service.file_factory") as mock_file_factory,
patch("services.account_service.FeatureService") as mock_feature_service,
patch("services.trigger.webhook_service.AsyncWorkflowService", autospec=True) as mock_async_service,
patch("services.trigger.webhook_service.ToolFileManager", autospec=True) as mock_tool_file_manager,
patch("services.trigger.webhook_service.file_factory", autospec=True) as mock_file_factory,
patch("services.account_service.FeatureService", autospec=True) as mock_feature_service,
):
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
@ -435,12 +432,12 @@ class TestWebhookService:
with flask_app_with_containers.app_context():
# Mock tenant owner lookup to return the test account
with patch("services.trigger.webhook_service.select") as mock_select:
with patch("services.trigger.webhook_service.select", autospec=True) as mock_select:
mock_query = MagicMock()
mock_select.return_value.join.return_value.where.return_value = mock_query
# Mock the session to return our test account
with patch("services.trigger.webhook_service.Session") as mock_session:
with patch("services.trigger.webhook_service.Session", autospec=True) as mock_session:
mock_session_instance = MagicMock()
mock_session.return_value.__enter__.return_value = mock_session_instance
mock_session_instance.scalar.return_value = test_data["account"]
@ -462,7 +459,7 @@ class TestWebhookService:
with flask_app_with_containers.app_context():
# Mock EndUserService to raise an exception
with patch(
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type"
"services.trigger.webhook_service.EndUserService.get_or_create_end_user_by_type", autospec=True
) as mock_end_user:
mock_end_user.side_effect = ValueError("Failed to create end user")

View File

@ -764,7 +764,7 @@ class TestWorkflowService:
# Act - Mock current_user context and pass session
from unittest.mock import patch
with patch("flask_login.utils._get_user", return_value=account):
with patch("flask_login.utils._get_user", return_value=account, autospec=True):
result = workflow_service.publish_workflow(
session=db_session_with_containers, app_model=app, account=account
)
@ -1401,6 +1401,7 @@ class TestWorkflowService:
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
autospec=True,
):
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs

View File

@ -18,7 +18,9 @@ class TestAddDocumentToIndexTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.add_document_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
patch(
"tasks.add_document_to_index_task.IndexProcessorFactory", autospec=True
) as mock_index_processor_factory,
):
# Setup mock index processor
mock_processor = MagicMock()
@ -378,7 +380,7 @@ class TestAddDocumentToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300)
# Mock the get_child_chunks method for each segment
with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
with patch.object(DocumentSegment, "get_child_chunks", autospec=True) as mock_get_child_chunks:
# Setup mock to return child chunks for each segment
mock_child_chunks = []
for i in range(2): # Each segment has 2 child chunks

View File

@ -51,9 +51,9 @@ class TestBatchCreateSegmentToIndexTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.batch_create_segment_to_index_task.storage") as mock_storage,
patch("tasks.batch_create_segment_to_index_task.ModelManager") as mock_model_manager,
patch("tasks.batch_create_segment_to_index_task.VectorService") as mock_vector_service,
patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage,
patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager,
patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service,
):
# Setup default mock returns
mock_storage.download.return_value = None

View File

@ -63,8 +63,8 @@ class TestCleanDatasetTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.clean_dataset_task.storage") as mock_storage,
patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_index_processor_factory,
patch("tasks.clean_dataset_task.storage", autospec=True) as mock_storage,
patch("tasks.clean_dataset_task.IndexProcessorFactory", autospec=True) as mock_index_processor_factory,
):
# Setup default mock returns
mock_storage.delete.return_value = None
@ -597,7 +597,7 @@ class TestCleanDatasetTask:
db_session_with_containers.commit()
# Mock the get_image_upload_file_ids function to return our image file IDs
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
with patch("tasks.clean_dataset_task.get_image_upload_file_ids", autospec=True) as mock_get_image_ids:
mock_get_image_ids.return_value = [f.id for f in image_files]
# Execute the task

View File

@ -41,7 +41,7 @@ class TestCreateSegmentToIndexTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.create_segment_to_index_task.IndexProcessorFactory") as mock_factory,
patch("tasks.create_segment_to_index_task.IndexProcessorFactory", autospec=True) as mock_factory,
):
# Setup default mock returns
mock_processor = MagicMock()
@ -708,7 +708,7 @@ class TestCreateSegmentToIndexTask:
redis_client.set(cache_key, "processing", ex=300)
# Mock Redis to raise exception in finally block
with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed")):
with patch.object(redis_client, "delete", side_effect=Exception("Redis connection failed"), autospec=True):
# Act: Execute the task - Redis failure should not prevent completion
with pytest.raises(Exception) as exc_info:
create_segment_to_index_task(segment.id)

View File

@ -37,7 +37,7 @@ class _TrackedSessionContext:
self._closed_sessions.append(self._session)
return original_close(*args, **kwargs)
self._close_patcher = patch.object(self._session, "close", side_effect=_tracked_close)
self._close_patcher = patch.object(self._session, "close", side_effect=_tracked_close, autospec=True)
self._close_patcher.start()
return self._session
@ -69,7 +69,9 @@ def session_close_tracker():
original_context_manager = original_create_session(*args, **kwargs)
return _TrackedSessionContext(original_context_manager, opened_sessions, closed_sessions)
with patch.object(task_module.session_factory, "create_session", side_effect=_tracked_create_session):
with patch.object(
task_module.session_factory, "create_session", side_effect=_tracked_create_session, autospec=True
):
yield {"opened_sessions": opened_sessions, "closed_sessions": closed_sessions}
@ -77,13 +79,11 @@ def session_close_tracker():
def patched_external_dependencies():
"""Patch non-DB collaborators while keeping database behavior real."""
with (
patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
patch("tasks.document_indexing_task.generate_summary_index_task") as mock_summary_task,
patch("tasks.document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner,
patch("tasks.document_indexing_task.FeatureService", autospec=True) as mock_feature_service,
patch("tasks.document_indexing_task.generate_summary_index_task", autospec=True) as mock_summary_task,
):
mock_runner_instance = MagicMock()
mock_indexing_runner.return_value = mock_runner_instance
mock_runner_instance = mock_indexing_runner.return_value
mock_features = MagicMock()
mock_features.billing.enabled = False
mock_features.billing.subscription.plan = CloudPlan.PROFESSIONAL
@ -307,9 +307,17 @@ class TestDatasetIndexingTaskIntegration:
# Act
with (
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[next_task]),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time") as set_waiting_spy,
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
return_value=[next_task],
autospec=True,
),
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True
) as set_waiting_spy,
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
) as delete_key_spy,
):
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
@ -336,8 +344,10 @@ class TestDatasetIndexingTaskIntegration:
# Act
with (
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True),
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
) as delete_key_spy,
):
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
@ -426,9 +436,13 @@ class TestDatasetIndexingTaskIntegration:
# Act
with (
patch("tasks.document_indexing_task._document_indexing", side_effect=Exception("failed")),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[next_task]),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time"),
patch("tasks.document_indexing_task._document_indexing", side_effect=Exception("failed"), autospec=True),
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
return_value=[next_task],
autospec=True,
),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True),
):
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
@ -511,8 +525,11 @@ class TestDatasetIndexingTaskIntegration:
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
return_value=pending_tasks[:concurrency_limit],
autospec=True,
),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time") as set_waiting_spy,
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True
) as set_waiting_spy,
):
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
@ -538,8 +555,12 @@ class TestDatasetIndexingTaskIntegration:
# Act
with (
patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=ordered_tasks),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time"),
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks",
return_value=ordered_tasks,
autospec=True,
),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.set_task_waiting_time", autospec=True),
):
_document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy)
@ -578,8 +599,10 @@ class TestDatasetIndexingTaskIntegration:
# Act
with (
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True),
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
) as delete_key_spy,
):
normal_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)
@ -599,8 +622,10 @@ class TestDatasetIndexingTaskIntegration:
# Act
with (
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[]),
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key") as delete_key_spy,
patch("tasks.document_indexing_task.TenantIsolatedTaskQueue.pull_tasks", return_value=[], autospec=True),
patch(
"tasks.document_indexing_task.TenantIsolatedTaskQueue.delete_task_key", autospec=True
) as delete_key_spy,
):
priority_document_indexing_task(dataset.tenant_id, dataset.id, document_ids)

View File

@ -216,7 +216,7 @@ class TestDeleteSegmentFromIndexTask:
db_session_with_containers.commit()
return segments
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
def test_delete_segment_from_index_task_success(self, mock_index_processor_factory, db_session_with_containers):
"""
Test successful segment deletion from index with comprehensive verification.
@ -399,7 +399,7 @@ class TestDeleteSegmentFromIndexTask:
# Verify the task completed without exceptions
assert result is None # Task should return None when indexing is not completed
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
def test_delete_segment_from_index_task_index_processor_clean(
self, mock_index_processor_factory, db_session_with_containers
):
@ -457,7 +457,7 @@ class TestDeleteSegmentFromIndexTask:
mock_index_processor_factory.reset_mock()
mock_processor.reset_mock()
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
def test_delete_segment_from_index_task_exception_handling(
self, mock_index_processor_factory, db_session_with_containers
):
@ -501,7 +501,7 @@ class TestDeleteSegmentFromIndexTask:
assert call_args[1]["with_keywords"] is True
assert call_args[1]["delete_child_chunks"] is True
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
def test_delete_segment_from_index_task_empty_index_node_ids(
self, mock_index_processor_factory, db_session_with_containers
):
@ -543,7 +543,7 @@ class TestDeleteSegmentFromIndexTask:
assert call_args[1]["with_keywords"] is True
assert call_args[1]["delete_child_chunks"] is True
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory")
@patch("tasks.delete_segment_from_index_task.IndexProcessorFactory", autospec=True)
def test_delete_segment_from_index_task_large_index_node_ids(
self, mock_index_processor_factory, db_session_with_containers
):

View File

@ -32,14 +32,11 @@ class TestDocumentIndexingTasks:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.document_indexing_task.IndexingRunner") as mock_indexing_runner,
patch("tasks.document_indexing_task.FeatureService") as mock_feature_service,
patch("tasks.document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner,
patch("tasks.document_indexing_task.FeatureService", autospec=True) as mock_feature_service,
):
# Setup mock indexing runner
mock_runner_instance = MagicMock()
mock_indexing_runner.return_value = mock_runner_instance
# Setup mock feature service
mock_runner_instance = mock_indexing_runner.return_value # Setup mock feature service
mock_features = MagicMock()
mock_features.billing.enabled = False
mock_feature_service.get_features.return_value = mock_features

View File

@ -16,15 +16,13 @@ class TestDocumentIndexingUpdateTask:
- IndexingRunner.run([...])
"""
with (
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
patch("tasks.document_indexing_update_task.IndexProcessorFactory", autospec=True) as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner", autospec=True) as mock_runner,
):
processor_instance = MagicMock()
mock_factory.return_value.init_index_processor.return_value = processor_instance
runner_instance = MagicMock()
mock_runner.return_value = runner_instance
runner_instance = mock_runner.return_value
yield {
"factory": mock_factory,
"processor": processor_instance,

View File

@ -31,15 +31,14 @@ class TestDuplicateDocumentIndexingTasks:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_indexing_runner,
patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_feature_service,
patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_index_processor_factory,
patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_indexing_runner,
patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_feature_service,
patch(
"tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True
) as mock_index_processor_factory,
):
# Setup mock indexing runner
mock_runner_instance = MagicMock()
mock_indexing_runner.return_value = mock_runner_instance
# Setup mock feature service
mock_runner_instance = mock_indexing_runner.return_value # Setup mock feature service
mock_features = MagicMock()
mock_features.billing.enabled = False
mock_feature_service.get_features.return_value = mock_features
@ -650,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_normal_duplicate_document_indexing_task_with_tenant_queue(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
):
@ -693,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_priority_duplicate_document_indexing_task_with_tenant_queue(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
):
@ -737,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks:
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True)
def test_tenant_queue_wrapper_processes_next_tasks(
self, mock_queue_class, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -18,7 +18,9 @@ class TestEnableSegmentsToIndexTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.enable_segments_to_index_task.IndexProcessorFactory") as mock_index_processor_factory,
patch(
"tasks.enable_segments_to_index_task.IndexProcessorFactory", autospec=True
) as mock_index_processor_factory,
):
# Setup mock index processor
mock_processor = MagicMock()
@ -370,7 +372,7 @@ class TestEnableSegmentsToIndexTask:
redis_client.set(indexing_cache_key, "processing", ex=300)
# Mock the get_child_chunks method for each segment
with patch.object(DocumentSegment, "get_child_chunks") as mock_get_child_chunks:
with patch.object(DocumentSegment, "get_child_chunks", autospec=True) as mock_get_child_chunks:
# Setup mock to return child chunks for each segment
mock_child_chunks = []
for i in range(2): # Each segment has 2 child chunks

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from faker import Faker
@ -16,16 +16,14 @@ class TestMailAccountDeletionTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.mail_account_deletion_task.mail") as mock_mail,
patch("tasks.mail_account_deletion_task.get_email_i18n_service") as mock_get_email_service,
patch("tasks.mail_account_deletion_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_account_deletion_task.get_email_i18n_service", autospec=True) as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
mock_email_service = mock_get_email_service.return_value
yield {
"mail": mock_mail,
"get_email_service": mock_get_email_service,

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from faker import Faker
@ -15,16 +15,14 @@ class TestMailChangeMailTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.mail_change_mail_task.mail") as mock_mail,
patch("tasks.mail_change_mail_task.get_email_i18n_service") as mock_get_email_i18n_service,
patch("tasks.mail_change_mail_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_change_mail_task.get_email_i18n_service", autospec=True) as mock_get_email_i18n_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email i18n service
mock_email_service = MagicMock()
mock_get_email_i18n_service.return_value = mock_email_service
mock_email_service = mock_get_email_i18n_service.return_value
yield {
"mail": mock_mail,
"email_i18n_service": mock_email_service,

View File

@ -53,8 +53,8 @@ class TestSendEmailCodeLoginMailTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.mail_email_code_login.mail") as mock_mail,
patch("tasks.mail_email_code_login.get_email_i18n_service") as mock_email_service,
patch("tasks.mail_email_code_login.mail", autospec=True) as mock_mail,
patch("tasks.mail_email_code_login.get_email_i18n_service", autospec=True) as mock_email_service,
):
# Setup default mock returns
mock_mail.is_inited.return_value = True
@ -573,7 +573,7 @@ class TestSendEmailCodeLoginMailTask:
mock_email_service_instance.send_email.side_effect = exception
# Mock logging to capture error messages
with patch("tasks.mail_email_code_login.logger") as mock_logger:
with patch("tasks.mail_email_code_login.logger", autospec=True) as mock_logger:
# Act: Execute the task - it should handle the exception gracefully
send_email_code_login_mail_task(
language=test_language,

View File

@ -1,4 +1,4 @@
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from faker import Faker
@ -13,18 +13,15 @@ class TestMailInnerTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.mail_inner_task.mail") as mock_mail,
patch("tasks.mail_inner_task.get_email_i18n_service") as mock_get_email_i18n_service,
patch("tasks.mail_inner_task._render_template_with_strategy") as mock_render_template,
patch("tasks.mail_inner_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_inner_task.get_email_i18n_service", autospec=True) as mock_get_email_i18n_service,
patch("tasks.mail_inner_task._render_template_with_strategy", autospec=True) as mock_render_template,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email i18n service
mock_email_service = MagicMock()
mock_get_email_i18n_service.return_value = mock_email_service
# Setup mock template rendering
mock_email_service = mock_get_email_i18n_service.return_value # Setup mock template rendering
mock_render_template.return_value = "<html>Test email content</html>"
yield {

View File

@ -56,9 +56,9 @@ class TestMailInviteMemberTask:
def mock_external_service_dependencies(self):
"""Mock setup for external service dependencies."""
with (
patch("tasks.mail_invite_member_task.mail") as mock_mail,
patch("tasks.mail_invite_member_task.get_email_i18n_service") as mock_email_service,
patch("tasks.mail_invite_member_task.dify_config") as mock_config,
patch("tasks.mail_invite_member_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_invite_member_task.get_email_i18n_service", autospec=True) as mock_email_service,
patch("tasks.mail_invite_member_task.dify_config", autospec=True) as mock_config,
):
# Setup mail service mock
mock_mail.is_inited.return_value = True
@ -306,7 +306,7 @@ class TestMailInviteMemberTask:
mock_email_service.send_email.side_effect = Exception("Email service failed")
# Act & Assert: Execute task and verify exception is handled
with patch("tasks.mail_invite_member_task.logger") as mock_logger:
with patch("tasks.mail_invite_member_task.logger", autospec=True) as mock_logger:
send_invite_member_mail_task(
language="en-US",
to="test@example.com",

View File

@ -7,7 +7,7 @@ testing with actual database and service dependencies.
"""
import logging
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from faker import Faker
@ -30,16 +30,14 @@ class TestMailOwnerTransferTask:
def mock_mail_dependencies(self):
"""Mock setup for mail service dependencies."""
with (
patch("tasks.mail_owner_transfer_task.mail") as mock_mail,
patch("tasks.mail_owner_transfer_task.get_email_i18n_service") as mock_get_email_service,
patch("tasks.mail_owner_transfer_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_owner_transfer_task.get_email_i18n_service", autospec=True) as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
mock_email_service = mock_get_email_service.return_value
yield {
"mail": mock_mail,
"email_service": mock_email_service,

View File

@ -5,7 +5,7 @@ This module provides integration tests for email registration tasks
using TestContainers to ensure real database and service interactions.
"""
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from faker import Faker
@ -21,16 +21,14 @@ class TestMailRegisterTask:
def mock_mail_dependencies(self):
"""Mock setup for mail service dependencies."""
with (
patch("tasks.mail_register_task.mail") as mock_mail,
patch("tasks.mail_register_task.get_email_i18n_service") as mock_get_email_service,
patch("tasks.mail_register_task.mail", autospec=True) as mock_mail,
patch("tasks.mail_register_task.get_email_i18n_service", autospec=True) as mock_get_email_service,
):
# Setup mock mail service
mock_mail.is_inited.return_value = True
# Setup mock email i18n service
mock_email_service = MagicMock()
mock_get_email_service.return_value = mock_email_service
mock_email_service = mock_get_email_service.return_value
yield {
"mail": mock_mail,
"email_service": mock_email_service,
@ -76,7 +74,7 @@ class TestMailRegisterTask:
to_email = fake.email()
code = fake.numerify("######")
with patch("tasks.mail_register_task.logger") as mock_logger:
with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger:
send_email_register_mail_task(language="en-US", to=to_email, code=code)
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)
@ -89,7 +87,7 @@ class TestMailRegisterTask:
to_email = fake.email()
account_name = fake.name()
with patch("tasks.mail_register_task.dify_config") as mock_config:
with patch("tasks.mail_register_task.dify_config", autospec=True) as mock_config:
mock_config.CONSOLE_WEB_URL = "https://console.dify.ai"
send_email_register_mail_task_when_account_exist(language=language, to=to_email, account_name=account_name)
@ -129,6 +127,6 @@ class TestMailRegisterTask:
to_email = fake.email()
account_name = fake.name()
with patch("tasks.mail_register_task.logger") as mock_logger:
with patch("tasks.mail_register_task.logger", autospec=True) as mock_logger:
send_email_register_mail_task_when_account_exist(language="en-US", to=to_email, account_name=account_name)
mock_logger.exception.assert_called_once_with("Send email register mail to %s failed", to_email)

View File

@ -12,9 +12,17 @@ def test_get_conversation_mark_read_keeps_updated_at_unchanged():
conversation.id = "conversation-id"
with (
patch("controllers.console.app.conversation.current_account_with_tenant", return_value=(account, None)),
patch("controllers.console.app.conversation.naive_utc_now", return_value=datetime(2026, 2, 9, 0, 0, 0)),
patch("controllers.console.app.conversation.db.session") as mock_session,
patch(
"controllers.console.app.conversation.current_account_with_tenant",
return_value=(account, None),
autospec=True,
),
patch(
"controllers.console.app.conversation.naive_utc_now",
return_value=datetime(2026, 2, 9, 0, 0, 0),
autospec=True,
),
patch("controllers.console.app.conversation.db.session", autospec=True) as mock_session,
):
mock_session.query.return_value.where.return_value.first.return_value = conversation

View File

@ -40,7 +40,7 @@ class TestWorkflowDraftVariableFields:
mock_variable.variable_file = mock_variable_file
# Mock the file helpers
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers:
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
# Call the function
@ -203,7 +203,7 @@ class TestWorkflowDraftVariableFields:
}
)
with patch("controllers.console.app.workflow_draft_variable.file_helpers") as mock_file_helpers:
with patch("controllers.console.app.workflow_draft_variable.file_helpers", autospec=True) as mock_file_helpers:
mock_file_helpers.get_signed_file_url.return_value = "http://example.com/signed-url"
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
expected_with_value = expected_without_value.copy()

View File

@ -47,8 +47,8 @@ class TestRefreshTokenApi:
token_pair.csrf_token = "new_csrf_token"
return token_pair
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_successful_token_refresh(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
"""
Test successful token refresh flow.
@ -73,7 +73,7 @@ class TestRefreshTokenApi:
mock_refresh_token.assert_called_once_with("valid_refresh_token")
assert response.json["result"] == "success"
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
def test_refresh_fails_without_token(self, mock_extract_token, app):
"""
Test token refresh failure when no refresh token provided.
@ -96,8 +96,8 @@ class TestRefreshTokenApi:
assert response["result"] == "fail"
assert "No refresh token provided" in response["message"]
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_invalid_token(self, mock_refresh_token, mock_extract_token, app):
"""
Test token refresh failure with invalid refresh token.
@ -121,8 +121,8 @@ class TestRefreshTokenApi:
assert response["result"] == "fail"
assert "Invalid refresh token" in response["message"]
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_fails_with_expired_token(self, mock_refresh_token, mock_extract_token, app):
"""
Test token refresh failure with expired refresh token.
@ -146,8 +146,8 @@ class TestRefreshTokenApi:
assert response["result"] == "fail"
assert "expired" in response["message"].lower()
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_with_empty_token(self, mock_refresh_token, mock_extract_token, app):
"""
Test token refresh with empty string token.
@ -168,8 +168,8 @@ class TestRefreshTokenApi:
assert status_code == 401
assert response["result"] == "fail"
@patch("controllers.console.auth.login.extract_refresh_token")
@patch("controllers.console.auth.login.AccountService.refresh_token")
@patch("controllers.console.auth.login.extract_refresh_token", autospec=True)
@patch("controllers.console.auth.login.AccountService.refresh_token", autospec=True)
def test_refresh_updates_all_tokens(self, mock_refresh_token, mock_extract_token, app, mock_token_pair):
"""
Test that token refresh updates all three tokens.

View File

@ -39,10 +39,12 @@ def client():
@patch(
"controllers.console.workspace.tool_providers.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")
"controllers.console.workspace.tool_providers.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
autospec=True,
)
@patch("controllers.console.workspace.tool_providers.Session")
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url")
@patch("controllers.console.workspace.tool_providers.Session", autospec=True)
@patch("controllers.console.workspace.tool_providers.MCPToolManageService._reconnect_with_url", autospec=True)
@pytest.mark.usefixtures("_mock_cache", "_mock_user_tenant")
def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_current_account_with_tenant, client):
# Arrange: reconnect returns tools immediately
@ -62,7 +64,7 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
svc.get_provider.return_value = MagicMock(id="provider-1", tenant_id="t1") # used by reload path
mock_session.return_value.__enter__.return_value = MagicMock()
# Patch MCPToolManageService constructed inside controller
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc):
with patch("controllers.console.workspace.tool_providers.MCPToolManageService", return_value=svc, autospec=True):
payload = {
"server_url": "http://example.com/mcp",
"name": "demo",
@ -77,12 +79,19 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_
# Act
with (
patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"), # bypass setup_required DB check
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MagicMock(id="u1"), "t1")),
patch("libs.login.check_csrf_token", return_value=None), # bypass CSRF in login_required
patch("libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True)), # login
patch(
"controllers.console.wraps.current_account_with_tenant",
return_value=(MagicMock(id="u1"), "t1"),
autospec=True,
),
patch("libs.login.check_csrf_token", return_value=None, autospec=True), # bypass CSRF in login_required
patch(
"libs.login._get_user", return_value=MagicMock(id="u1", is_authenticated=True), autospec=True
), # login
patch(
"services.tools.tools_transform_service.ToolTransformService.mcp_provider_to_user_provider",
return_value={"id": "provider-1", "tools": [{"name": "ping"}]},
autospec=True,
),
):
resp = client.post(

View File

@ -77,7 +77,7 @@ class DummyResult:
class TestMCPAppApi:
@patch.object(module, "handle_mcp_request", return_value=DummyResult())
@patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True)
def test_success_request(self, mock_handle):
fake_payload(
{
@ -321,7 +321,7 @@ class TestMCPAppApi:
post_fn("server-1")
assert "App is unavailable" in str(exc_info.value)
@patch.object(module, "handle_mcp_request", return_value=None)
@patch.object(module, "handle_mcp_request", return_value=None, autospec=True)
def test_mcp_request_no_response(self, mock_handle):
"""Test when handle_mcp_request returns None"""
fake_payload(
@ -380,7 +380,7 @@ class TestMCPAppApi:
api = module.MCPAppApi()
api._get_mcp_server_and_app = MagicMock(return_value=(server, app))
with patch.object(module, "handle_mcp_request", return_value=DummyResult()):
with patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True):
post_fn = unwrap(api.post)
response = post_fn("server-1")
assert isinstance(response, Response)
@ -409,7 +409,7 @@ class TestMCPAppApi:
api = module.MCPAppApi()
api._get_mcp_server_and_app = MagicMock(return_value=(server, app))
with patch.object(module, "handle_mcp_request", return_value=DummyResult()):
with patch.object(module, "handle_mcp_request", return_value=DummyResult(), autospec=True):
post_fn = unwrap(api.post)
response = post_fn("server-1")
assert isinstance(response, Response)

View File

@ -12,7 +12,7 @@ from controllers.service_api.index import IndexApi
class TestIndexApi:
"""Test suite for IndexApi resource."""
@patch("controllers.service_api.index.dify_config")
@patch("controllers.service_api.index.dify_config", autospec=True)
def test_get_returns_api_info(self, mock_config, app):
"""Test that GET returns API metadata with correct structure."""
# Arrange

View File

@ -71,17 +71,17 @@ class TestBaseAppRunnerMultimodal:
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
@ -158,17 +158,17 @@ class TestBaseAppRunnerMultimodal:
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
@ -231,17 +231,17 @@ class TestBaseAppRunnerMultimodal:
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
@ -282,9 +282,9 @@ class TestBaseAppRunnerMultimodal:
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
@ -321,14 +321,14 @@ class TestBaseAppRunnerMultimodal:
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
# Setup mock to raise exception
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
@ -368,17 +368,17 @@ class TestBaseAppRunnerMultimodal:
)
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
@ -420,17 +420,17 @@ class TestBaseAppRunnerMultimodal:
)
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.ToolFileManager", autospec=True) as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.MessageFile", autospec=True) as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
with patch("core.app.apps.base_app_runner.db.session", autospec=True) as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()

View File

@ -84,7 +84,7 @@ def mock_time():
mock_time_val += seconds
return mock_time_val
with patch("time.time", return_value=mock_time_val) as mock:
with patch("time.time", return_value=mock_time_val, autospec=True) as mock:
mock.increment = increment_time
yield mock

View File

@ -9,7 +9,7 @@ from core.helper.ssrf_proxy import (
)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_successful_request(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
@ -22,7 +22,7 @@ def test_successful_request(mock_get_client):
mock_client.request.assert_called_once()
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_retry_exceed_max_retries(mock_get_client):
mock_client = MagicMock()
mock_response = MagicMock()
@ -71,7 +71,7 @@ class TestGetUserProvidedHostHeader:
assert result in ("first.com", "second.com")
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_host_header_preservation_with_user_header(mock_get_client):
"""Test that user-provided Host header is preserved in the request."""
mock_client = MagicMock()
@ -89,7 +89,7 @@ def test_host_header_preservation_with_user_header(mock_get_client):
assert call_kwargs["headers"]["host"] == custom_host
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
@pytest.mark.parametrize("host_key", ["host", "HOST", "Host"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
"""Test that Host header is preserved regardless of case."""
@ -113,7 +113,7 @@ class TestFollowRedirectsParameter:
These tests verify that follow_redirects is correctly passed to client.request().
"""
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_follow_redirects_passed_to_request(self, mock_get_client):
"""Verify follow_redirects IS passed to client.request()."""
mock_client = MagicMock()
@ -128,7 +128,7 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert call_kwargs.get("follow_redirects") is True
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_allow_redirects_converted_to_follow_redirects(self, mock_get_client):
"""Verify allow_redirects (requests-style) is converted to follow_redirects (httpx-style)."""
mock_client = MagicMock()
@ -145,7 +145,7 @@ class TestFollowRedirectsParameter:
assert call_kwargs.get("follow_redirects") is True
assert "allow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_follow_redirects_not_set_when_not_specified(self, mock_get_client):
"""Verify follow_redirects is not in kwargs when not specified (httpx default behavior)."""
mock_client = MagicMock()
@ -160,7 +160,7 @@ class TestFollowRedirectsParameter:
call_kwargs = mock_client.request.call_args.kwargs
assert "follow_redirects" not in call_kwargs
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@patch("core.helper.ssrf_proxy._get_ssrf_client", autospec=True)
def test_follow_redirects_takes_precedence_over_allow_redirects(self, mock_get_client):
"""Verify follow_redirects takes precedence when both are specified."""
mock_client = MagicMock()

View File

@ -72,7 +72,7 @@ class TestTraceContextFilter:
mock_span.get_span_context.return_value = mock_context
with (
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span),
mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True),
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
):
@ -108,7 +108,9 @@ class TestIdentityContextFilter:
filter = IdentityContextFilter()
# Should not raise even if something goes wrong
with mock.patch("core.logging.filters.flask.has_request_context", side_effect=Exception("Test error")):
with mock.patch(
"core.logging.filters.flask.has_request_context", side_effect=Exception("Test error"), autospec=True
):
result = filter.filter(log_record)
assert result is True
assert log_record.tenant_id == ""

View File

@ -8,7 +8,7 @@ class TestGetSpanIdFromOtelContext:
def test_returns_none_without_span(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True):
result = get_span_id_from_otel_context()
assert result is None
@ -20,7 +20,7 @@ class TestGetSpanIdFromOtelContext:
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True):
with mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0):
result = get_span_id_from_otel_context()
assert result == "051581bf3bb55c45"
@ -28,7 +28,7 @@ class TestGetSpanIdFromOtelContext:
def test_returns_none_on_exception(self):
from core.helper.trace_id_helper import get_span_id_from_otel_context
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error")):
with mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception("Test error"), autospec=True):
result = get_span_id_from_otel_context()
assert result is None
@ -37,7 +37,7 @@ class TestGenerateTraceparentHeader:
def test_generates_valid_format(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True):
result = generate_traceparent_header()
assert result is not None
@ -58,7 +58,7 @@ class TestGenerateTraceparentHeader:
mock_context.span_id = 0x051581BF3BB55C45
mock_span.get_span_context.return_value = mock_context
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span):
with mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span, autospec=True):
with (
mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0),
mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0),
@ -70,7 +70,7 @@ class TestGenerateTraceparentHeader:
def test_generates_hex_only_values(self):
from core.helper.trace_id_helper import generate_traceparent_header
with mock.patch("opentelemetry.trace.get_current_span", return_value=None):
with mock.patch("opentelemetry.trace.get_current_span", return_value=None, autospec=True):
result = generate_traceparent_header()
parts = result.split("-")

View File

@ -32,7 +32,7 @@ class TestConstants:
class TestCreateSSRFProxyMCPHTTPClient:
"""Test create_ssrf_proxy_mcp_http_client function."""
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_with_all_url_proxy(self, mock_config):
"""Test client creation with SSRF_PROXY_ALL_URL configured."""
mock_config.SSRF_PROXY_ALL_URL = "http://proxy.example.com:8080"
@ -50,7 +50,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_with_http_https_proxies(self, mock_config):
"""Test client creation with separate HTTP/HTTPS proxies."""
mock_config.SSRF_PROXY_ALL_URL = None
@ -66,7 +66,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_without_proxy(self, mock_config):
"""Test client creation without proxy configuration."""
mock_config.SSRF_PROXY_ALL_URL = None
@ -88,7 +88,7 @@ class TestCreateSSRFProxyMCPHTTPClient:
# Clean up
client.close()
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.dify_config", autospec=True)
def test_create_client_default_params(self, mock_config):
"""Test client creation with default parameters."""
mock_config.SSRF_PROXY_ALL_URL = None
@ -111,8 +111,8 @@ class TestCreateSSRFProxyMCPHTTPClient:
class TestSSRFProxySSEConnect:
"""Test ssrf_proxy_sse_connect function."""
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.connect_sse", autospec=True)
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
def test_sse_connect_with_provided_client(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with pre-configured client."""
# Setup mocks
@ -138,9 +138,9 @@ class TestSSRFProxySSEConnect:
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.dify_config")
@patch("core.mcp.utils.connect_sse", autospec=True)
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
@patch("core.mcp.utils.dify_config", autospec=True)
def test_sse_connect_without_client(self, mock_config, mock_create_client, mock_connect_sse):
"""Test SSE connection without pre-configured client."""
# Setup config
@ -183,8 +183,8 @@ class TestSSRFProxySSEConnect:
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.connect_sse", autospec=True)
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
def test_sse_connect_with_custom_timeout(self, mock_create_client, mock_connect_sse):
"""Test SSE connection with custom timeout."""
# Setup mocks
@ -209,8 +209,8 @@ class TestSSRFProxySSEConnect:
# Verify result
assert result == mock_context
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client")
@patch("core.mcp.utils.connect_sse", autospec=True)
@patch("core.mcp.utils.create_ssrf_proxy_mcp_http_client", autospec=True)
def test_sse_connect_error_cleanup(self, mock_create_client, mock_connect_sse):
"""Test SSE connection cleans up client on error."""
# Setup mocks
@ -227,7 +227,7 @@ class TestSSRFProxySSEConnect:
# Verify client was cleaned up
mock_client.close.assert_called_once()
@patch("core.mcp.utils.connect_sse")
@patch("core.mcp.utils.connect_sse", autospec=True)
def test_sse_connect_error_no_cleanup_with_provided_client(self, mock_connect_sse):
"""Test SSE connection doesn't clean up provided client on error."""
# Setup mocks

View File

@ -324,7 +324,7 @@ class TestOpenAIModeration:
with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"):
OpenAIModeration.validate_config("test-tenant", config)
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
"""Test input moderation when OpenAI API returns no violations."""
# Mock the model manager and instance
@ -341,7 +341,7 @@ class TestOpenAIModeration:
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == "Content flagged by OpenAI moderation."
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
"""Test input moderation when OpenAI API detects violations."""
# Mock the model manager to return violation
@ -358,7 +358,7 @@ class TestOpenAIModeration:
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == "Content flagged by OpenAI moderation."
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
"""Test that query is included in moderation check with special key."""
mock_instance = MagicMock()
@ -385,7 +385,7 @@ class TestOpenAIModeration:
assert "u" in moderated_text
assert "e" in moderated_text
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock):
"""Test input moderation when inputs_config is disabled."""
config = {
@ -400,7 +400,7 @@ class TestOpenAIModeration:
# Should not call the API when disabled
mock_model_manager.assert_not_called()
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
"""Test output moderation when OpenAI API returns no violations."""
mock_instance = MagicMock()
@ -414,7 +414,7 @@ class TestOpenAIModeration:
assert result.action == ModerationAction.DIRECT_OUTPUT
assert result.preset_response == "Response blocked by moderation."
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration):
"""Test output moderation when OpenAI API detects violations."""
mock_instance = MagicMock()
@ -427,7 +427,7 @@ class TestOpenAIModeration:
assert result.flagged is True
assert result.action == ModerationAction.DIRECT_OUTPUT
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock):
"""Test output moderation when outputs_config is disabled."""
config = {
@ -441,7 +441,7 @@ class TestOpenAIModeration:
assert result.flagged is False
mock_model_manager.assert_not_called()
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_model_manager_called_with_correct_params(
self, mock_model_manager: Mock, openai_moderation: OpenAIModeration
):
@ -494,7 +494,7 @@ class TestModerationRuleStructure:
class TestModerationFactoryIntegration:
"""Test suite for ModerationFactory integration."""
@patch("core.moderation.factory.code_based_extension")
@patch("core.moderation.factory.code_based_extension", autospec=True)
def test_factory_delegates_to_extension(self, mock_extension: Mock):
"""Test ModerationFactory delegates to extension system."""
from core.moderation.factory import ModerationFactory
@ -518,7 +518,7 @@ class TestModerationFactoryIntegration:
assert result.flagged is False
mock_instance.moderation_for_inputs.assert_called_once()
@patch("core.moderation.factory.code_based_extension")
@patch("core.moderation.factory.code_based_extension", autospec=True)
def test_factory_validate_config_delegates(self, mock_extension: Mock):
"""Test ModerationFactory.validate_config delegates to extension."""
from core.moderation.factory import ModerationFactory
@ -629,7 +629,7 @@ class TestPresetManagement:
assert result.flagged is True
assert result.preset_response == "Custom output blocked message"
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock):
"""Test preset response is properly returned for OpenAI input violations."""
mock_instance = MagicMock()
@ -650,7 +650,7 @@ class TestPresetManagement:
assert result.flagged is True
assert result.preset_response == "OpenAI input blocked"
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock):
"""Test preset response is properly returned for OpenAI output violations."""
mock_instance = MagicMock()
@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced:
- Performance considerations
"""
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_api_timeout_handling(self, mock_model_manager: Mock):
"""
Test graceful handling of OpenAI API timeouts.
@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced:
with pytest.raises(TimeoutError):
moderation.moderation_for_inputs({"text": "test"}, "")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock):
"""
Test handling of OpenAI API rate limit errors.
@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced:
with pytest.raises(Exception, match="Rate limit exceeded"):
moderation.moderation_for_inputs({"text": "test"}, "")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock):
"""
Test OpenAI moderation with multiple input fields.
@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced:
assert "u" in moderated_text
assert "e" in moderated_text
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_empty_text_handling(self, mock_model_manager: Mock):
"""
Test OpenAI moderation with empty text inputs.
@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced:
assert result.flagged is False
mock_instance.invoke_moderation.assert_called_once()
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager")
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True)
def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock):
"""
Test that ModelManager fetches a fresh model instance on each call.

View File

@ -64,7 +64,7 @@ class TestPluginEndpointClientDelete:
"data": True,
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
@ -102,7 +102,7 @@ class TestPluginEndpointClientDelete:
),
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
@ -139,7 +139,7 @@ class TestPluginEndpointClientDelete:
),
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
endpoint_client.delete_endpoint(
@ -174,7 +174,7 @@ class TestPluginEndpointClientDelete:
"message": '{"error_type": "PluginDaemonInternalServerError", "message": "Record Not Found"}',
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = endpoint_client.delete_endpoint(
tenant_id=tenant_id,
@ -222,7 +222,7 @@ class TestPluginEndpointClientDelete:
),
}
with patch("httpx.request") as mock_request:
with patch("httpx.request", autospec=True) as mock_request:
# Act - first call
mock_request.return_value = mock_response_success
result1 = endpoint_client.delete_endpoint(
@ -266,7 +266,7 @@ class TestPluginEndpointClientDelete:
"message": '{"error_type": "PluginDaemonUnauthorizedError", "message": "unauthorized access"}',
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(Exception) as exc_info:
endpoint_client.delete_endpoint(

View File

@ -114,7 +114,7 @@ class TestPluginRuntimeExecution:
mock_response.status_code = 200
mock_response.json.return_value = {"result": "success"}
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
response = plugin_client._request("GET", "plugin/test-tenant/management/list")
@ -132,7 +132,7 @@ class TestPluginRuntimeExecution:
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("GET", "plugin/test-tenant/test")
@ -143,7 +143,7 @@ class TestPluginRuntimeExecution:
def test_request_connection_error(self, plugin_client, mock_config):
"""Test handling of connection errors during request."""
# Arrange
with patch("httpx.request", side_effect=httpx.RequestError("Connection failed")):
with patch("httpx.request", side_effect=httpx.RequestError("Connection failed"), autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInnerError) as exc_info:
plugin_client._request("GET", "plugin/test-tenant/test")
@ -182,7 +182,7 @@ class TestPluginRuntimeSandboxIsolation:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("GET", "plugin/test-tenant/test")
@ -201,7 +201,7 @@ class TestPluginRuntimeSandboxIsolation:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": {"result": "isolated_execution"}}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = plugin_client._request_with_plugin_daemon_response(
"POST", "plugin/test-tenant/dispatch/tool/invoke", TestResponse, data={"tool": "test"}
@ -218,7 +218,7 @@ class TestPluginRuntimeSandboxIsolation:
error_message = json.dumps({"error_type": "PluginDaemonUnauthorizedError", "message": "Unauthorized access"})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonUnauthorizedError) as exc_info:
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -234,7 +234,7 @@ class TestPluginRuntimeSandboxIsolation:
)
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginPermissionDeniedError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
@ -272,7 +272,7 @@ class TestPluginRuntimeResourceLimits:
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("GET", "plugin/test-tenant/test")
@ -283,7 +283,7 @@ class TestPluginRuntimeResourceLimits:
def test_timeout_error_handling(self, plugin_client, mock_config):
"""Test handling of timeout errors."""
# Arrange
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timeout")):
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timeout"), autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInnerError) as exc_info:
plugin_client._request("GET", "plugin/test-tenant/test")
@ -292,7 +292,7 @@ class TestPluginRuntimeResourceLimits:
def test_streaming_request_timeout(self, plugin_client, mock_config):
"""Test timeout handling for streaming requests."""
# Arrange
with patch("httpx.stream", side_effect=httpx.TimeoutException("Stream timeout")):
with patch("httpx.stream", side_effect=httpx.TimeoutException("Stream timeout"), autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInnerError) as exc_info:
list(plugin_client._stream_request("POST", "plugin/test-tenant/stream"))
@ -308,7 +308,7 @@ class TestPluginRuntimeResourceLimits:
)
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInternalServerError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
@ -352,7 +352,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(InvokeRateLimitError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
@ -371,7 +371,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(InvokeAuthorizationError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
@ -390,7 +390,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(InvokeBadRequestError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
@ -409,7 +409,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(InvokeConnectionError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
@ -428,7 +428,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(InvokeServerUnavailableError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
@ -446,7 +446,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": json.dumps(invoke_error)})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(CredentialsValidateFailedError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/validate", bool)
@ -462,7 +462,7 @@ class TestPluginRuntimeErrorHandling:
)
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginNotFoundError) as exc_info:
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/get", bool)
@ -478,7 +478,7 @@ class TestPluginRuntimeErrorHandling:
)
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginUniqueIdentifierError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/install", bool)
@ -494,7 +494,7 @@ class TestPluginRuntimeErrorHandling:
)
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonBadRequestError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
@ -508,7 +508,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginDaemonNotFoundError", "message": "Resource not found"})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonNotFoundError) as exc_info:
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/resource", bool)
@ -526,7 +526,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "PluginInvokeError", "message": invoke_error_message})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginInvokeError) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/invoke", bool)
@ -540,7 +540,7 @@ class TestPluginRuntimeErrorHandling:
error_message = json.dumps({"error_type": "UnknownErrorType", "message": "Unknown error occurred"})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(Exception) as exc_info:
plugin_client._request_with_plugin_daemon_response("POST", "plugin/test-tenant/test", bool)
@ -555,7 +555,7 @@ class TestPluginRuntimeErrorHandling:
"Server Error", request=MagicMock(), response=mock_response
)
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(httpx.HTTPStatusError):
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -567,7 +567,7 @@ class TestPluginRuntimeErrorHandling:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -610,7 +610,7 @@ class TestPluginRuntimeCommunication:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": {"value": "test", "count": 42}}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = plugin_client._request_with_plugin_daemon_response(
"POST", "plugin/test-tenant/test", TestModel, data={"input": "data"}
@ -637,7 +637,7 @@ class TestPluginRuntimeCommunication:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -667,7 +667,7 @@ class TestPluginRuntimeCommunication:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -689,7 +689,7 @@ class TestPluginRuntimeCommunication:
def test_streaming_connection_error(self, plugin_client, mock_config):
"""Test connection error during streaming."""
# Arrange
with patch("httpx.stream", side_effect=httpx.RequestError("Stream connection failed")):
with patch("httpx.stream", side_effect=httpx.RequestError("Stream connection failed"), autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInnerError) as exc_info:
list(plugin_client._stream_request("POST", "plugin/test-tenant/stream"))
@ -707,7 +707,7 @@ class TestPluginRuntimeCommunication:
mock_response.status_code = 200
mock_response.json.return_value = {"status": "success", "data": {"key": "value"}}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = plugin_client._request_with_model("GET", "plugin/test-tenant/direct", DirectModel)
@ -732,7 +732,7 @@ class TestPluginRuntimeCommunication:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -764,7 +764,7 @@ class TestPluginRuntimeCommunication:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -814,7 +814,7 @@ class TestPluginToolManagerIntegration:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -844,7 +844,7 @@ class TestPluginToolManagerIntegration:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -868,7 +868,7 @@ class TestPluginToolManagerIntegration:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -892,7 +892,7 @@ class TestPluginToolManagerIntegration:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -945,7 +945,7 @@ class TestPluginInstallerIntegration:
},
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.list_plugins("test-tenant")
@ -959,7 +959,7 @@ class TestPluginInstallerIntegration:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.uninstall("test-tenant", "plugin-installation-id")
@ -973,7 +973,7 @@ class TestPluginInstallerIntegration:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.fetch_plugin_by_identifier("test-tenant", "plugin-identifier")
@ -1012,7 +1012,7 @@ class TestPluginRuntimeEdgeCases:
mock_response.status_code = 200
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(ValueError):
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -1025,7 +1025,7 @@ class TestPluginRuntimeEdgeCases:
# Missing required fields in response
mock_response.json.return_value = {"invalid": "structure"}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(ValueError):
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -1041,7 +1041,7 @@ class TestPluginRuntimeEdgeCases:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1065,7 +1065,7 @@ class TestPluginRuntimeEdgeCases:
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("POST", "plugin/test-tenant/upload", data=b"binary data")
@ -1081,7 +1081,7 @@ class TestPluginRuntimeEdgeCases:
files = {"file": ("test.txt", b"file content", "text/plain")}
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("POST", "plugin/test-tenant/upload", files=files)
@ -1095,7 +1095,7 @@ class TestPluginRuntimeEdgeCases:
mock_response = MagicMock()
mock_response.iter_lines.return_value = []
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1115,7 +1115,7 @@ class TestPluginRuntimeEdgeCases:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act & Assert
@ -1136,7 +1136,7 @@ class TestPluginRuntimeEdgeCases:
mock_response.status_code = 200
mock_response.json.return_value = {"code": -1, "message": "Plain text error message", "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(ValueError) as exc_info:
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -1174,7 +1174,7 @@ class TestPluginRuntimeAdvancedScenarios:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
for i in range(5):
result = plugin_client._request_with_plugin_daemon_response("GET", f"plugin/test-tenant/test/{i}", bool)
@ -1203,7 +1203,7 @@ class TestPluginRuntimeAdvancedScenarios:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": complex_data}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = plugin_client._request_with_plugin_daemon_response(
"POST", "plugin/test-tenant/complex", ComplexModel
@ -1231,7 +1231,7 @@ class TestPluginRuntimeAdvancedScenarios:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1262,7 +1262,7 @@ class TestPluginRuntimeAdvancedScenarios:
mock_response.status_code = 200
return mock_response
with patch("httpx.request", side_effect=side_effect):
with patch("httpx.request", side_effect=side_effect, autospec=True):
# Act & Assert - First two calls should fail
with pytest.raises(PluginDaemonInnerError):
plugin_client._request("GET", "plugin/test-tenant/test")
@ -1286,7 +1286,7 @@ class TestPluginRuntimeAdvancedScenarios:
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("GET", "plugin/test-tenant/test", headers=custom_headers)
@ -1312,7 +1312,7 @@ class TestPluginRuntimeAdvancedScenarios:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1359,7 +1359,7 @@ class TestPluginRuntimeSecurityAndValidation:
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request("GET", "plugin/test-tenant/test")
@ -1381,7 +1381,7 @@ class TestPluginRuntimeSecurityAndValidation:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": True}
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request_with_plugin_daemon_response(
"POST",
@ -1403,7 +1403,7 @@ class TestPluginRuntimeSecurityAndValidation:
error_message = json.dumps({"error_type": "PluginDaemonUnauthorizedError", "message": "Invalid API key"})
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonUnauthorizedError) as exc_info:
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool)
@ -1424,7 +1424,7 @@ class TestPluginRuntimeSecurityAndValidation:
)
mock_response.json.return_value = {"code": -1, "message": error_message, "data": None}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonBadRequestError) as exc_info:
plugin_client._request_with_plugin_daemon_response(
@ -1438,7 +1438,7 @@ class TestPluginRuntimeSecurityAndValidation:
mock_response = MagicMock()
mock_response.status_code = 200
with patch("httpx.request", return_value=mock_response) as mock_request:
with patch("httpx.request", return_value=mock_response, autospec=True) as mock_request:
# Act
plugin_client._request(
"POST", "plugin/test-tenant/test", headers={"Content-Type": "application/json"}, data={"key": "value"}
@ -1489,7 +1489,7 @@ class TestPluginRuntimePerformanceScenarios:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1524,7 +1524,7 @@ class TestPluginRuntimePerformanceScenarios:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act - Process chunks one by one
@ -1539,7 +1539,7 @@ class TestPluginRuntimePerformanceScenarios:
def test_timeout_with_slow_response(self, plugin_client, mock_config):
"""Test timeout handling with slow response simulation."""
# Arrange
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timed out after 30s")):
with patch("httpx.request", side_effect=httpx.TimeoutException("Request timed out after 30s"), autospec=True):
# Act & Assert
with pytest.raises(PluginDaemonInnerError) as exc_info:
plugin_client._request("GET", "plugin/test-tenant/slow-endpoint")
@ -1554,7 +1554,7 @@ class TestPluginRuntimePerformanceScenarios:
request_results = []
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act - Simulate 10 concurrent requests
for i in range(10):
result = plugin_client._request_with_plugin_daemon_response(
@ -1612,7 +1612,7 @@ class TestPluginToolManagerAdvanced:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1641,7 +1641,7 @@ class TestPluginToolManagerAdvanced:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1673,7 +1673,7 @@ class TestPluginToolManagerAdvanced:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1704,7 +1704,7 @@ class TestPluginToolManagerAdvanced:
mock_response = MagicMock()
mock_response.iter_lines.return_value = [line.encode("utf-8") for line in stream_data]
with patch("httpx.stream") as mock_stream:
with patch("httpx.stream", autospec=True) as mock_stream:
mock_stream.return_value.__enter__.return_value = mock_response
# Act
@ -1770,7 +1770,7 @@ class TestPluginInstallerAdvanced:
},
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.upload_pkg("test-tenant", plugin_package, verify_signature=False)
@ -1788,7 +1788,7 @@ class TestPluginInstallerAdvanced:
"data": {"content": "# Plugin README\n\nThis is a test plugin.", "language": "en"},
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en")
@ -1807,7 +1807,7 @@ class TestPluginInstallerAdvanced:
mock_response.raise_for_status = raise_for_status
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act & Assert - Should raise HTTPStatusError for 404
with pytest.raises(httpx.HTTPStatusError):
installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en")
@ -1826,7 +1826,7 @@ class TestPluginInstallerAdvanced:
},
}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.list_plugins_with_total("test-tenant", page=2, page_size=20)
@ -1848,7 +1848,7 @@ class TestPluginInstallerAdvanced:
mock_response.status_code = 200
mock_response.json.return_value = {"code": 0, "message": "", "data": [True, False]}
with patch("httpx.request", return_value=mock_response):
with patch("httpx.request", return_value=mock_response, autospec=True):
# Act
result = installer.check_tools_existence("test-tenant", provider_ids)

View File

@ -142,7 +142,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg
prompt_transform = AdvancedPromptTransform()
prompt_transform._calculate_rest_token = MagicMock(return_value=2000)
with patch("core.workflow.file.file_manager.to_prompt_message_content") as mock_get_encoded_string:
with patch("core.workflow.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string:
mock_get_encoded_string.return_value = ImagePromptMessageContent(
url=str(files[0].remote_url), format="jpg", mime_type="image/jpg"
)

View File

@ -83,7 +83,7 @@ def test_extract_images_formats(mock_dependencies, monkeypatch, image_bytes, exp
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
# We need to handle the import inside _extract_images
with patch("pypdfium2.raw") as mock_raw:
with patch("pypdfium2.raw", autospec=True) as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
@ -115,7 +115,7 @@ def test_extract_images_get_objects_scenarios(mock_dependencies, get_objects_sid
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
with patch("pypdfium2.raw") as mock_raw:
with patch("pypdfium2.raw", autospec=True) as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
@ -133,11 +133,11 @@ def test_extract_calls_extract_images(mock_dependencies, monkeypatch):
mock_text_page.get_text_range.return_value = "Page text content"
mock_page.get_textpage.return_value = mock_text_page
with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc):
with patch("pypdfium2.PdfDocument", return_value=mock_pdf_doc, autospec=True):
# Mock Blob
mock_blob = MagicMock()
mock_blob.source = "test.pdf"
with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob):
with patch("core.rag.extractor.pdf_extractor.Blob.from_path", return_value=mock_blob, autospec=True):
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
# Mock _extract_images to return a known string
@ -175,7 +175,7 @@ def test_extract_images_failures(mock_dependencies):
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id="t1", user_id="u1")
with patch("pypdfium2.raw") as mock_raw:
with patch("pypdfium2.raw", autospec=True) as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)

View File

@ -52,7 +52,7 @@ class TestRerankModelRunner:
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@ -397,19 +397,19 @@ class TestWeightRerankRunner:
@pytest.fixture
def mock_model_manager(self):
"""Mock ModelManager for embedding model."""
with patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager:
with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager:
yield mock_manager
@pytest.fixture
def mock_cache_embedding(self):
"""Mock CacheEmbedding for vector operations."""
with patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache:
with patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache:
yield mock_cache
@pytest.fixture
def mock_jieba_handler(self):
"""Mock JiebaKeywordTableHandler for keyword extraction."""
with patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba:
with patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba:
yield mock_jieba
@pytest.fixture
@ -914,7 +914,7 @@ class TestRerankIntegration:
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@ -1026,7 +1026,7 @@ class TestRerankEdgeCases:
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@ -1295,9 +1295,9 @@ class TestRerankEdgeCases:
# Mock dependencies
with (
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache,
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
):
mock_handler = MagicMock()
mock_handler.extract_keywords.return_value = ["test"]
@ -1367,7 +1367,7 @@ class TestRerankPerformance:
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@ -1441,9 +1441,9 @@ class TestRerankPerformance:
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights)
with (
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache,
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
):
mock_handler = MagicMock()
# Track keyword extraction calls
@ -1484,7 +1484,7 @@ class TestRerankErrorHandling:
@pytest.fixture(autouse=True)
def mock_model_manager(self):
"""Auto-use fixture to patch ModelManager for all tests in this class."""
with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm:
with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm:
mock_mm.return_value.check_model_support_vision.return_value = False
yield mock_mm
@ -1592,9 +1592,9 @@ class TestRerankErrorHandling:
runner = WeightRerankRunner(tenant_id="tenant123", weights=weights)
with (
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler") as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager") as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding") as mock_cache,
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba,
patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager,
patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache,
):
mock_handler = MagicMock()
mock_handler.extract_keywords.return_value = ["test"]

View File

@ -48,7 +48,7 @@ class TestRepositoryFactory:
import_string("invalidpath")
assert "doesn't look like a module path" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_execution_repository_success(self, mock_config):
"""Test successful WorkflowExecutionRepository creation."""
# Setup mock configuration
@ -66,7 +66,7 @@ class TestRepositoryFactory:
mock_repository_class.return_value = mock_repository_instance
# Mock import_string
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
@ -83,7 +83,7 @@ class TestRepositoryFactory:
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_execution_repository_import_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
@ -101,7 +101,7 @@ class TestRepositoryFactory:
)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowExecutionRepository creation with instantiation error."""
# Setup mock configuration
@ -115,7 +115,7 @@ class TestRepositoryFactory:
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock import_string to return a failing class
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory,
@ -125,7 +125,7 @@ class TestRepositoryFactory:
)
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_node_execution_repository_success(self, mock_config):
"""Test successful WorkflowNodeExecutionRepository creation."""
# Setup mock configuration
@ -143,7 +143,7 @@ class TestRepositoryFactory:
mock_repository_class.return_value = mock_repository_instance
# Mock import_string
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
user=mock_user,
@ -160,7 +160,7 @@ class TestRepositoryFactory:
)
assert result is mock_repository_instance
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with import error."""
# Setup mock configuration with invalid class path
@ -178,7 +178,7 @@ class TestRepositoryFactory:
)
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
# Setup mock configuration
@ -192,7 +192,7 @@ class TestRepositoryFactory:
mock_repository_class.side_effect = Exception("Instantiation failed")
# Mock import_string to return a failing class
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory,
@ -208,7 +208,7 @@ class TestRepositoryFactory:
error = RepositoryImportError(error_message)
assert str(error) == error_message
@patch("core.repositories.factory.dify_config")
@patch("core.repositories.factory.dify_config", autospec=True)
def test_create_with_engine_instead_of_sessionmaker(self, mock_config):
"""Test repository creation with Engine instead of sessionmaker."""
# Setup mock configuration
@ -226,7 +226,7 @@ class TestRepositoryFactory:
mock_repository_class.return_value = mock_repository_instance
# Mock import_string
with patch("core.repositories.factory.import_string", return_value=mock_repository_class):
with patch("core.repositories.factory.import_string", return_value=mock_repository_class, autospec=True):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user,

View File

@ -196,7 +196,7 @@ class TestSchemaResolver:
resolved1 = resolve_dify_schema_refs(schema)
# Mock the registry to return different data
with patch.object(self.registry, "get_schema") as mock_get:
with patch.object(self.registry, "get_schema", autospec=True) as mock_get:
mock_get.return_value = {"type": "different"}
# Second resolution should use cache
@ -445,7 +445,7 @@ class TestSchemaResolverClass:
# Second resolver should use the same cache
resolver2 = SchemaResolver()
with patch.object(resolver2.registry, "get_schema") as mock_get:
with patch.object(resolver2.registry, "get_schema", autospec=True) as mock_get:
result2 = resolver2.resolve(schema)
# Should not call registry since it's in cache
mock_get.assert_not_called()

View File

@ -138,8 +138,8 @@ class TestFlaskExecutionContext:
class TestCaptureFlaskContext:
"""Test capture_flask_context function."""
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
@patch("context.flask_app_context.current_app", autospec=True)
@patch("context.flask_app_context.g", autospec=True)
def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
"""Test capture_flask_context captures Flask app."""
mock_app = MagicMock()
@ -152,8 +152,8 @@ class TestCaptureFlaskContext:
assert ctx._flask_app == mock_app
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
@patch("context.flask_app_context.current_app", autospec=True)
@patch("context.flask_app_context.g", autospec=True)
def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
"""Test capture_flask_context captures user from Flask g object."""
mock_app = MagicMock()
@ -170,7 +170,7 @@ class TestCaptureFlaskContext:
assert ctx.user == mock_user
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.current_app", autospec=True)
def test_capture_flask_context_with_explicit_user(self, mock_current_app):
"""Test capture_flask_context uses explicit user parameter."""
mock_app = MagicMock()
@ -186,7 +186,7 @@ class TestCaptureFlaskContext:
assert ctx.user == explicit_user
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.current_app", autospec=True)
def test_capture_flask_context_captures_contextvars(self, mock_current_app):
"""Test capture_flask_context captures context variables."""
mock_app = MagicMock()
@ -267,7 +267,7 @@ class TestFlaskExecutionContextIntegration:
# Verify app context was entered
assert mock_flask_app.app_context.called
@patch("context.flask_app_context.g")
@patch("context.flask_app_context.g", autospec=True)
def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
"""Test that enter restores user in Flask g object."""
mock_user = MagicMock()

View File

@ -138,10 +138,10 @@ class TestGraphRuntimeState:
_ = state.response_coordinator
mock_graph = MagicMock()
with patch("core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator") as coordinator_cls:
coordinator_instance = MagicMock()
coordinator_cls.return_value = coordinator_instance
with patch(
"core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True
) as coordinator_cls:
coordinator_instance = coordinator_cls.return_value
state.configure(graph=mock_graph)
assert state.response_coordinator is coordinator_instance
@ -204,7 +204,7 @@ class TestGraphRuntimeState:
mock_graph = MagicMock()
stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub):
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True):
state.attach_graph(mock_graph)
stub.state = "configured"
@ -230,7 +230,7 @@ class TestGraphRuntimeState:
assert restored_execution.started is True
new_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True):
restored.attach_graph(mock_graph)
assert new_stub.state == "configured"
@ -251,14 +251,14 @@ class TestGraphRuntimeState:
mock_graph = MagicMock()
original_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub):
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True):
state.attach_graph(mock_graph)
original_stub.state = "configured"
snapshot = state.dumps()
new_stub = StubCoordinator()
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub):
with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True):
restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
restored.attach_graph(mock_graph)
restored.loads(snapshot)

View File

@ -63,7 +63,7 @@ class TestPrivateWorkflowPauseEntity:
assert entity.resumed_at is None
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True)
def test_get_state_first_call(self, mock_storage):
"""Test get_state loads from storage on first call."""
state_data = b'{"test": "data", "step": 5}'
@ -81,7 +81,7 @@ class TestPrivateWorkflowPauseEntity:
mock_storage.load.assert_called_once_with("test-state-key")
assert entity._cached_state == state_data
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True)
def test_get_state_cached_call(self, mock_storage):
"""Test get_state returns cached data on subsequent calls."""
state_data = b'{"test": "data", "step": 5}'
@ -102,7 +102,7 @@ class TestPrivateWorkflowPauseEntity:
# Storage should only be called once
mock_storage.load.assert_called_once_with("test-state-key")
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage")
@patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True)
def test_get_state_with_pre_cached_data(self, mock_storage):
"""Test get_state returns pre-cached data."""
state_data = b'{"test": "data", "step": 5}'
@ -125,7 +125,7 @@ class TestPrivateWorkflowPauseEntity:
# Test with binary data that's not valid JSON
binary_data = b"\x00\x01\x02\x03\x04\x05\xff\xfe"
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage:
with patch("repositories.sqlalchemy_api_workflow_run_repository.storage", autospec=True) as mock_storage:
mock_storage.load.return_value = binary_data
mock_pause_model = MagicMock(spec=WorkflowPauseModel)

View File

@ -90,14 +90,14 @@ def mock_tool_node():
@pytest.fixture
def mock_is_instrument_flag_enabled_false():
"""Mock is_instrument_flag_enabled to return False."""
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False):
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=False, autospec=True):
yield
@pytest.fixture
def mock_is_instrument_flag_enabled_true():
"""Mock is_instrument_flag_enabled to return True."""
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True):
with patch("core.app.workflow.layers.observability.is_instrument_flag_enabled", return_value=True, autospec=True):
yield

View File

@ -117,9 +117,7 @@ def test_parallel_streaming_workflow():
# Create node factory and graph
node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True
):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

View File

@ -378,7 +378,7 @@ class TestStopEventIntegration:
class TestStopEventTimeoutBehavior:
"""Test stop_event behavior with join timeouts."""
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread")
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
"""Test that Dispatcher uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
@ -405,7 +405,7 @@ class TestStopEventTimeoutBehavior:
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker")
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
"""Test that WorkerPool uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())

View File

@ -198,11 +198,10 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE
provider_model_bundle.configuration.__class__,
"get_provider_model",
return_value=provider_model,
autospec=True,
),
mock.patch.object(
model_type_instance.__class__,
"get_model_schema",
return_value=model_config.model_schema,
model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True
),
):
fetch_model_config(

View File

@ -128,7 +128,8 @@ class TestTemplateTransformNode:
assert TemplateTransformNode.version() == "1"
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_simple_template(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
@ -165,7 +166,8 @@ class TestTemplateTransformNode:
assert result.inputs["age"] == 30
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with None variable values."""
@ -192,7 +194,8 @@ class TestTemplateTransformNode:
assert result.inputs["value"] is None
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_code_execution_error(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
@ -215,7 +218,8 @@ class TestTemplateTransformNode:
assert "Template syntax error" in result.error
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_output_length_exceeds_limit(
self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params
@ -239,7 +243,8 @@ class TestTemplateTransformNode:
assert "Output length exceeds" in result.error
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_complex_jinja2_template(
self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params
@ -303,7 +308,8 @@ class TestTemplateTransformNode:
assert mapping["node_123.var2"] == ["sys", "input2"]
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with no variables (static template)."""
@ -330,7 +336,8 @@ class TestTemplateTransformNode:
assert result.inputs == {}
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with numeric variable values."""
@ -369,7 +376,8 @@ class TestTemplateTransformNode:
assert result.outputs["output"] == "Total: $31.5"
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with dictionary variable values."""
@ -400,7 +408,8 @@ class TestTemplateTransformNode:
assert "john@example.com" in result.outputs["output"]
@patch(
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template"
"core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template",
autospec=True,
)
def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params):
"""Test _run with list variable values."""

View File

@ -92,7 +92,9 @@ def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[lis
return messages
tool_runtime = MagicMock()
with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform):
with patch.object(
ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True
):
generator = tool_node._transform_message(
messages=iter([message]),
tool_info={"provider_type": "builtin", "provider_id": "provider"},

View File

@ -26,11 +26,8 @@ class TestWorkflowEntryRedisChannel:
redis_channel = RedisChannel(mock_redis_client, "test:channel:key")
# Patch GraphEngine to verify it receives the Redis channel
with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine:
mock_graph_engine = MagicMock()
MockGraphEngine.return_value = mock_graph_engine
# Create WorkflowEntry with Redis channel
with patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine:
mock_graph_engine = MockGraphEngine.return_value # Create WorkflowEntry with Redis channel
workflow_entry = WorkflowEntry(
tenant_id="test-tenant",
app_id="test-app",
@ -63,15 +60,11 @@ class TestWorkflowEntryRedisChannel:
# Patch GraphEngine and InMemoryChannel
with (
patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine,
patch("core.workflow.workflow_entry.InMemoryChannel") as MockInMemoryChannel,
patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine,
patch("core.workflow.workflow_entry.InMemoryChannel", autospec=True) as MockInMemoryChannel,
):
mock_graph_engine = MagicMock()
MockGraphEngine.return_value = mock_graph_engine
mock_inmemory_channel = MagicMock()
MockInMemoryChannel.return_value = mock_inmemory_channel
# Create WorkflowEntry without providing a channel
mock_graph_engine = MockGraphEngine.return_value
mock_inmemory_channel = MockInMemoryChannel.return_value # Create WorkflowEntry without providing a channel
workflow_entry = WorkflowEntry(
tenant_id="test-tenant",
app_id="test-app",
@ -114,7 +107,7 @@ class TestWorkflowEntryRedisChannel:
mock_event2 = MagicMock()
# Patch GraphEngine
with patch("core.workflow.workflow_entry.GraphEngine") as MockGraphEngine:
with patch("core.workflow.workflow_entry.GraphEngine", autospec=True) as MockGraphEngine:
mock_graph_engine = MagicMock()
mock_graph_engine.run.return_value = iter([mock_event1, mock_event2])
MockGraphEngine.return_value = mock_graph_engine

View File

@ -40,7 +40,7 @@ def mock_upload_file():
mock.source_url = TEST_REMOTE_URL
mock.size = 1024
mock.key = "test_key"
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
with patch("factories.file_factory.db.session.scalar", return_value=mock, autospec=True) as m:
yield m
@ -54,7 +54,7 @@ def mock_tool_file():
mock.mimetype = "application/pdf"
mock.original_url = "http://example.com/tool.pdf"
mock.size = 2048
with patch("factories.file_factory.db.session.scalar", return_value=mock):
with patch("factories.file_factory.db.session.scalar", return_value=mock, autospec=True):
yield mock
@ -70,7 +70,7 @@ def mock_http_head():
},
)
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
with patch("factories.file_factory.ssrf_proxy.head", autospec=True) as mock_head:
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
yield mock_head
@ -188,7 +188,7 @@ def test_build_from_remote_url_without_strict_validation(mock_http_head):
def test_tool_file_not_found():
"""Test ToolFile not found in database."""
with patch("factories.file_factory.db.session.scalar", return_value=None):
with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True):
mapping = tool_file_mapping()
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
@ -196,7 +196,7 @@ def test_tool_file_not_found():
def test_local_file_not_found():
"""Test UploadFile not found in database."""
with patch("factories.file_factory.db.session.scalar", return_value=None):
with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True):
mapping = local_file_mapping()
with pytest.raises(ValueError, match="Invalid upload file"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
@ -268,7 +268,7 @@ def test_tenant_mismatch():
mock_file.key = "test_key"
# Mock the database query to return None (no file found for this tenant)
with patch("factories.file_factory.db.session.scalar", return_value=None):
with patch("factories.file_factory.db.session.scalar", return_value=None, autospec=True):
mapping = local_file_mapping()
with pytest.raises(ValueError, match="Invalid upload file"):
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)

View File

@ -403,7 +403,7 @@ class TestRedisSubscription:
# ==================== Listener Thread Tests ====================
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
@patch("time.sleep", side_effect=lambda x: None, autospec=True) # Speed up test
def test_listener_thread_normal_operation(
self, mock_sleep, subscription: _RedisSubscription, mock_pubsub: MagicMock
):
@ -826,7 +826,7 @@ class TestRedisShardedSubscription:
# ==================== Listener Thread Tests ====================
@patch("time.sleep", side_effect=lambda x: None) # Speed up test
@patch("time.sleep", side_effect=lambda x: None, autospec=True) # Speed up test
def test_listener_thread_normal_operation(
self, mock_sleep, sharded_subscription: _RedisShardedSubscription, mock_pubsub: MagicMock
):

View File

@ -104,7 +104,7 @@ class TestParseTimeRange:
def test_parse_time_range_dst_ambiguous_time(self):
"""Test parsing during DST ambiguous time (fall back)."""
# This test simulates DST fall back where 2:30 AM occurs twice
with patch("pytz.timezone") as mock_timezone:
with patch("pytz.timezone", autospec=True) as mock_timezone:
# Mock timezone that raises AmbiguousTimeError
mock_tz = mock_timezone.return_value
@ -135,7 +135,7 @@ class TestParseTimeRange:
def test_parse_time_range_dst_nonexistent_time(self):
"""Test parsing during DST nonexistent time (spring forward)."""
with patch("pytz.timezone") as mock_timezone:
with patch("pytz.timezone", autospec=True) as mock_timezone:
# Mock timezone that raises NonExistentTimeError
mock_tz = mock_timezone.return_value

View File

@ -55,7 +55,7 @@ class TestLoginRequired:
with setup_app.test_request_context():
# Mock authenticated user
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Protected content"
@ -70,7 +70,7 @@ class TestLoginRequired:
with setup_app.test_request_context():
# Mock unauthenticated user
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Unauthorized"
setup_app.login_manager.unauthorized.assert_called_once()
@ -86,8 +86,8 @@ class TestLoginRequired:
with setup_app.test_request_context():
# Mock unauthenticated user and LOGIN_DISABLED
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login.dify_config") as mock_config:
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
with patch("libs.login.dify_config", autospec=True) as mock_config:
mock_config.LOGIN_DISABLED = True
result = protected_view()
@ -106,7 +106,7 @@ class TestLoginRequired:
with setup_app.test_request_context(method="OPTIONS"):
# Mock unauthenticated user
mock_user = MockUser("test_user", is_authenticated=False)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Protected content"
# Ensure unauthorized was not called
@ -125,7 +125,7 @@ class TestLoginRequired:
with setup_app.test_request_context():
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Synced content"
setup_app.ensure_sync.assert_called_once()
@ -144,7 +144,7 @@ class TestLoginRequired:
with setup_app.test_request_context():
mock_user = MockUser("test_user", is_authenticated=True)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
result = protected_view()
assert result == "Protected content"
@ -197,14 +197,14 @@ class TestCurrentUser:
mock_user = MockUser("test_user", is_authenticated=True)
with app.test_request_context():
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
assert current_user.id == "test_user"
assert current_user.is_authenticated is True
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
"""Test that current_user proxy handles None user."""
with app.test_request_context():
with patch("libs.login._get_user", return_value=None):
with patch("libs.login._get_user", return_value=None, autospec=True):
# When _get_user returns None, accessing attributes should fail
# or current_user should evaluate to falsy
try:
@ -224,7 +224,7 @@ class TestCurrentUser:
def check_user_in_thread(user_id: str, index: int):
with app.test_request_context():
mock_user = MockUser(user_id)
with patch("libs.login._get_user", return_value=mock_user):
with patch("libs.login._get_user", return_value=mock_user, autospec=True):
results[index] = current_user.id
# Create multiple threads with different users

View File

@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
({}, None, True),
],
)
@patch("httpx.post")
@patch("httpx.post", autospec=True)
def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
):
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
),
],
)
@patch("httpx.get")
@patch("httpx.get", autospec=True)
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock()
user_response.json.return_value = user_data
@ -121,7 +121,7 @@ class TestGitHubOAuth(BaseOAuthTest):
assert user_info.name == user_data["name"]
assert user_info.email == expected_email
@patch("httpx.get")
@patch("httpx.get", autospec=True)
def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = httpx.RequestError("Network error")
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({}, None, True),
],
)
@patch("httpx.post")
@patch("httpx.post", autospec=True)
def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
):
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
],
)
@patch("httpx.get")
@patch("httpx.get", autospec=True)
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data
mock_get.return_value = mock_response
@ -222,7 +222,7 @@ class TestGoogleOAuth(BaseOAuthTest):
httpx.TimeoutException,
],
)
@patch("httpx.get")
@patch("httpx.get", autospec=True)
def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error")

View File

@ -9,11 +9,9 @@ def _mail() -> dict:
return {"to": "user@example.com", "subject": "Hi", "html": "<b>Hi</b>"}
@patch("libs.smtp.smtplib.SMTP")
@patch("libs.smtp.smtplib.SMTP", autospec=True)
def test_smtp_plain_success(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp_cls.return_value = mock_smtp
mock_smtp = mock_smtp_cls.return_value
client = SMTPClient(server="smtp.example.com", port=25, username="", password="", _from="noreply@example.com")
client.send(_mail())
@ -22,11 +20,9 @@ def test_smtp_plain_success(mock_smtp_cls: MagicMock):
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
@patch("libs.smtp.smtplib.SMTP", autospec=True)
def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp_cls.return_value = mock_smtp
mock_smtp = mock_smtp_cls.return_value
client = SMTPClient(
server="smtp.example.com",
port=587,
@ -46,7 +42,7 @@ def test_smtp_tls_opportunistic_success(mock_smtp_cls: MagicMock):
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP_SSL")
@patch("libs.smtp.smtplib.SMTP_SSL", autospec=True)
def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
# Cover SMTP_SSL branch and TimeoutError handling
mock_smtp = MagicMock()
@ -67,7 +63,7 @@ def test_smtp_tls_ssl_branch_and_timeout(mock_smtp_ssl_cls: MagicMock):
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
@patch("libs.smtp.smtplib.SMTP", autospec=True)
def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
mock_smtp = MagicMock()
mock_smtp.sendmail.side_effect = RuntimeError("oops")
@ -79,7 +75,7 @@ def test_smtp_generic_exception_propagates(mock_smtp_cls: MagicMock):
mock_smtp.quit.assert_called_once()
@patch("libs.smtp.smtplib.SMTP")
@patch("libs.smtp.smtplib.SMTP", autospec=True)
def test_smtp_smtplib_exception_in_login(mock_smtp_cls: MagicMock):
# Ensure we hit the specific SMTPException except branch
import smtplib

View File

@ -301,7 +301,7 @@ class TestAppModelConfig:
)
# Mock database query to return None
with patch("models.model.db.session.query") as mock_query:
with patch("models.model.db.session.query", autospec=True) as mock_query:
mock_query.return_value.where.return_value.first.return_value = None
# Act
@ -952,7 +952,7 @@ class TestSiteModel:
def test_site_generate_code(self):
"""Test Site.generate_code static method."""
# Mock database query to return 0 (no existing codes)
with patch("models.model.db.session.query") as mock_query:
with patch("models.model.db.session.query", autospec=True) as mock_query:
mock_query.return_value.where.return_value.count.return_value = 0
# Act
@ -1167,7 +1167,7 @@ class TestConversationStatusCount:
conversation.id = str(uuid4())
# Mock the database query to return no messages
with patch("models.model.db.session.scalars") as mock_scalars:
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
@ -1192,7 +1192,7 @@ class TestConversationStatusCount:
conversation.id = conversation_id
# Mock the database query to return no messages with workflow_run_id
with patch("models.model.db.session.scalars") as mock_scalars:
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
mock_scalars.return_value.all.return_value = []
# Act
@ -1277,7 +1277,7 @@ class TestConversationStatusCount:
return mock_result
# Act & Assert
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
result = conversation.status_count
# Verify only 2 database queries were made (not N+1)
@ -1340,7 +1340,7 @@ class TestConversationStatusCount:
return mock_result
# Act
with patch("models.model.db.session.scalars", side_effect=mock_scalars):
with patch("models.model.db.session.scalars", side_effect=mock_scalars, autospec=True):
result = conversation.status_count
# Assert - query should include app_id filter
@ -1385,7 +1385,7 @@ class TestConversationStatusCount:
),
]
with patch("models.model.db.session.scalars") as mock_scalars:
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
# Mock the messages query
def mock_scalars_side_effect(query):
mock_result = MagicMock()
@ -1441,7 +1441,7 @@ class TestConversationStatusCount:
),
]
with patch("models.model.db.session.scalars") as mock_scalars:
with patch("models.model.db.session.scalars", autospec=True) as mock_scalars:
def mock_scalars_side_effect(query):
mock_result = MagicMock()

View File

@ -15,7 +15,7 @@ class TestTencentCos(BaseStorageTest):
@pytest.fixture(autouse=True)
def setup_method(self, setup_tencent_cos_mock):
"""Executed before each test method."""
with patch.object(CosConfig, "__init__", return_value=None):
with patch.object(CosConfig, "__init__", return_value=None, autospec=True):
self.storage = TencentCosStorage()
self.storage.bucket_name = get_example_bucket()
@ -39,9 +39,9 @@ class TestTencentCosConfiguration:
with (
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
patch(
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance, autospec=True
) as mock_cos_config,
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client, autospec=True),
):
TencentCosStorage()
@ -72,9 +72,9 @@ class TestTencentCosConfiguration:
with (
patch("extensions.storage.tencent_cos_storage.dify_config", mock_dify_config),
patch(
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance
"extensions.storage.tencent_cos_storage.CosConfig", return_value=mock_config_instance, autospec=True
) as mock_cos_config,
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client),
patch("extensions.storage.tencent_cos_storage.CosS3Client", return_value=mock_client, autospec=True),
):
TencentCosStorage()

View File

@ -19,7 +19,7 @@ class TestApiKeyAuthFactory:
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):
"""Test getting auth factory for all valid providers"""
with patch(auth_class_path) as mock_auth:
with patch(auth_class_path, autospec=True) as mock_auth:
auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider)
assert auth_class == mock_auth
@ -46,7 +46,7 @@ class TestApiKeyAuthFactory:
(False, False),
],
)
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
def test_validate_credentials_delegates_to_auth_instance(
self, mock_get_factory, credentials_return_value, expected_result
):
@ -65,7 +65,7 @@ class TestApiKeyAuthFactory:
assert result is expected_result
mock_auth_instance.validate_credentials.assert_called_once()
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory")
@patch("services.auth.api_key_auth_factory.ApiKeyAuthFactory.get_apikey_auth_factory", autospec=True)
def test_validate_credentials_propagates_exceptions(self, mock_get_factory):
"""Test that exceptions from auth instance are propagated"""
# Arrange

View File

@ -65,7 +65,7 @@ class TestFirecrawlAuth:
FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -96,7 +96,7 @@ class TestFirecrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@ -118,7 +118,7 @@ class TestFirecrawlAuth:
(401, "Not JSON", True, "Failed to authorize. Status code: 401. Error: Not JSON"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@ -145,7 +145,7 @@ class TestFirecrawlAuth:
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message)
@ -167,7 +167,7 @@ class TestFirecrawlAuth:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation and normalized"""
mock_response = MagicMock()
@ -185,7 +185,7 @@ class TestFirecrawlAuth:
assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.firecrawl.firecrawl.httpx.post", autospec=True)
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

View File

@ -35,7 +35,7 @@ class TestJinaAuth:
JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -53,7 +53,7 @@ class TestJinaAuth:
json={"url": "https://example.com"},
)
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error"""
mock_response = MagicMock()
@ -68,7 +68,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error"""
mock_response = MagicMock()
@ -83,7 +83,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error"""
mock_response = MagicMock()
@ -98,7 +98,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response"""
mock_response = MagicMock()
@ -114,7 +114,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response"""
mock_response = MagicMock()
@ -130,7 +130,7 @@ class TestJinaAuth:
auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.httpx.post")
@patch("services.auth.jina.jina.httpx.post", autospec=True)
def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors"""
mock_post.side_effect = httpx.ConnectError("Network error")

View File

@ -64,7 +64,7 @@ class TestWatercrawlAuth:
WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation"""
mock_response = MagicMock()
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
(500, "Internal server error"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes"""
mock_response = MagicMock()
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
):
@ -132,7 +132,7 @@ class TestWatercrawlAuth:
(httpx.ConnectTimeout, "Connection timeout"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message)
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation"""
mock_response = MagicMock()
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
],
)
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock()
@ -193,7 +193,7 @@ class TestWatercrawlAuth:
# Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.httpx.get")
@patch("services.auth.watercrawl.watercrawl.httpx.get", autospec=True)
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")

View File

@ -27,7 +27,7 @@ class TestTraceparentPropagation:
@pytest.fixture
def mock_httpx_client(self):
"""Mock httpx.Client for testing."""
with patch("services.enterprise.base.httpx.Client") as mock_client_class:
with patch("services.enterprise.base.httpx.Client", autospec=True) as mock_client_class:
mock_client = MagicMock()
mock_client_class.return_value.__enter__.return_value = mock_client
mock_client_class.return_value.__exit__.return_value = None
@ -44,7 +44,9 @@ class TestTraceparentPropagation:
# Arrange
expected_traceparent = "00-5b8aa5a2d2c872e8321cf37308d69df2-051581bf3bb55c45-01"
with patch("services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent):
with patch(
"services.enterprise.base.generate_traceparent_header", return_value=expected_traceparent, autospec=True
):
# Act
EnterpriseRequest.send_request("GET", "/test")

View File

@ -135,8 +135,8 @@ class TestExternalDatasetServiceGetExternalKnowledgeApis:
"""
with (
patch("services.external_knowledge_service.db.paginate") as mock_paginate,
patch("services.external_knowledge_service.select"),
patch("services.external_knowledge_service.db.paginate", autospec=True) as mock_paginate,
patch("services.external_knowledge_service.select", autospec=True),
):
yield mock_paginate
@ -245,7 +245,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
Patch ``db.session`` for all CRUD tests in this class.
"""
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_create_external_knowledge_api_success(self, mock_db_session: MagicMock):
@ -263,7 +263,7 @@ class TestExternalDatasetServiceCrudExternalKnowledgeApi:
}
# We do not want to actually call the remote endpoint here, so we patch the validator.
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key") as mock_check:
with patch.object(ExternalDatasetService, "check_endpoint_and_api_key", autospec=True) as mock_check:
result = ExternalDatasetService.create_external_knowledge_api(tenant_id, user_id, args)
assert isinstance(result, ExternalKnowledgeApis)
@ -386,7 +386,7 @@ class TestExternalDatasetServiceUsageAndBindings:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_external_knowledge_api_use_check_in_use(self, mock_db_session: MagicMock):
@ -447,7 +447,7 @@ class TestExternalDatasetServiceDocumentCreateArgsValidate:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_document_create_args_validate_success(self, mock_db_session: MagicMock):
@ -520,7 +520,7 @@ class TestExternalDatasetServiceProcessExternalApi:
fake_response = httpx.Response(200)
with patch("services.external_knowledge_service.ssrf_proxy.post") as mock_post:
with patch("services.external_knowledge_service.ssrf_proxy.post", autospec=True) as mock_post:
mock_post.return_value = fake_response
result = ExternalDatasetService.process_external_api(settings, files=None)
@ -681,7 +681,7 @@ class TestExternalDatasetServiceCreateExternalDataset:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_create_external_dataset_success(self, mock_db_session: MagicMock):
@ -801,7 +801,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
@pytest.fixture
def mock_db_session(self):
with patch("services.external_knowledge_service.db.session") as mock_session:
with patch("services.external_knowledge_service.db.session", autospec=True) as mock_session:
yield mock_session
def test_fetch_external_knowledge_retrieval_success(self, mock_db_session: MagicMock):
@ -838,7 +838,9 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
metadata_condition = SimpleNamespace(model_dump=lambda: {"field": "value"})
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response) as mock_process:
with patch.object(
ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True
) as mock_process:
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id=tenant_id,
dataset_id=dataset_id,
@ -908,7 +910,7 @@ class TestExternalDatasetServiceFetchExternalKnowledgeRetrieval:
fake_response.status_code = 500
fake_response.json.return_value = {}
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response):
with patch.object(ExternalDatasetService, "process_external_api", return_value=fake_response, autospec=True):
result = ExternalDatasetService.fetch_external_knowledge_retrieval(
tenant_id="tenant-1",
dataset_id="ds-1",

View File

@ -146,7 +146,7 @@ class TestHitTestingServiceRetrieve:
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session") as mock_db:
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_retrieve_success_with_default_retrieval_model(self, mock_db_session):
@ -174,9 +174,11 @@ class TestHitTestingServiceRetrieve:
]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1] # start, end
mock_retrieve.return_value = documents
@ -218,9 +220,11 @@ class TestHitTestingServiceRetrieve:
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_retrieve.return_value = documents
@ -268,10 +272,12 @@ class TestHitTestingServiceRetrieve:
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
@ -311,8 +317,10 @@ class TestHitTestingServiceRetrieve:
mock_dataset_retrieval.get_metadata_filter_condition.return_value = ({}, True)
with (
patch("services.hit_testing_service.DatasetRetrieval") as mock_dataset_retrieval_class,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.DatasetRetrieval", autospec=True) as mock_dataset_retrieval_class,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
):
mock_dataset_retrieval_class.return_value = mock_dataset_retrieval
mock_format.return_value = []
@ -346,9 +354,11 @@ class TestHitTestingServiceRetrieve:
mock_records = [HitTestingTestDataFactory.create_retrieval_record_mock()]
with (
patch("services.hit_testing_service.RetrievalService.retrieve") as mock_retrieve,
patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch("services.hit_testing_service.RetrievalService.retrieve", autospec=True) as mock_retrieve,
patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_retrieve.return_value = documents
@ -380,7 +390,7 @@ class TestHitTestingServiceExternalRetrieve:
Provides a mocked database session for testing database operations
like adding and committing DatasetQuery records.
"""
with patch("services.hit_testing_service.db.session") as mock_db:
with patch("services.hit_testing_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_external_retrieve_success(self, mock_db_session):
@ -403,8 +413,10 @@ class TestHitTestingServiceExternalRetrieve:
]
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch(
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
) as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = external_documents
@ -467,8 +479,10 @@ class TestHitTestingServiceExternalRetrieve:
external_documents = [{"content": "Doc 1", "title": "Title", "score": 0.9, "metadata": {}}]
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch(
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
) as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = external_documents
@ -499,8 +513,10 @@ class TestHitTestingServiceExternalRetrieve:
metadata_filtering_conditions = {}
with (
patch("services.hit_testing_service.RetrievalService.external_retrieve") as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter") as mock_perf_counter,
patch(
"services.hit_testing_service.RetrievalService.external_retrieve", autospec=True
) as mock_external_retrieve,
patch("services.hit_testing_service.time.perf_counter", autospec=True) as mock_perf_counter,
):
mock_perf_counter.side_effect = [0.0, 0.1]
mock_external_retrieve.return_value = []
@ -542,7 +558,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
HitTestingTestDataFactory.create_retrieval_record_mock(content="Doc 2", score=0.85),
]
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
with patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format:
mock_format.return_value = mock_records
# Act
@ -566,7 +584,9 @@ class TestHitTestingServiceCompactRetrieveResponse:
query = "test query"
documents = []
with patch("services.hit_testing_service.RetrievalService.format_retrieval_documents") as mock_format:
with patch(
"services.hit_testing_service.RetrievalService.format_retrieval_documents", autospec=True
) as mock_format:
mock_format.return_value = []
# Act

View File

@ -147,7 +147,7 @@ class TestSegmentServiceCreateSegment:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -172,10 +172,12 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -219,10 +221,12 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -257,11 +261,13 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.ModelManager") as mock_model_manager_class,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -292,10 +298,12 @@ class TestSegmentServiceCreateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = mock_segment
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_segments_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_segments_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -317,7 +325,7 @@ class TestSegmentServiceUpdateSegment:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -338,10 +346,10 @@ class TestSegmentServiceUpdateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = segment
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_redis_get.return_value = None # Not indexing
mock_hash.return_value = "new-hash"
@ -368,10 +376,10 @@ class TestSegmentServiceUpdateSegment:
args = SegmentUpdateArgs(enabled=False)
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
patch("services.dataset_service.disable_segment_from_index_task") as mock_task,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
patch("services.dataset_service.disable_segment_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_redis_get.return_value = None
mock_now.return_value = "2024-01-01T00:00:00"
@ -394,7 +402,7 @@ class TestSegmentServiceUpdateSegment:
dataset = SegmentTestDataFactory.create_dataset_mock()
args = SegmentUpdateArgs(content="Updated content")
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
mock_redis_get.return_value = "1" # Indexing in progress
# Act & Assert
@ -409,7 +417,7 @@ class TestSegmentServiceUpdateSegment:
dataset = SegmentTestDataFactory.create_dataset_mock()
args = SegmentUpdateArgs(content="Updated content")
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
mock_redis_get.return_value = None
# Act & Assert
@ -427,10 +435,10 @@ class TestSegmentServiceUpdateSegment:
mock_db_session.query.return_value.where.return_value.first.return_value = segment
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.VectorService.update_segment_vector", autospec=True) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_redis_get.return_value = None
mock_hash.return_value = "new-hash"
@ -456,7 +464,7 @@ class TestSegmentServiceDeleteSegment:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_delete_segment_success(self, mock_db_session):
@ -471,10 +479,10 @@ class TestSegmentServiceDeleteSegment:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.redis_client.setex") as mock_redis_setex,
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
patch("services.dataset_service.select") as mock_select,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.redis_client.setex", autospec=True) as mock_redis_setex,
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.select", autospec=True) as mock_select,
):
mock_redis_get.return_value = None
mock_select.return_value.where.return_value = mock_select
@ -495,8 +503,8 @@ class TestSegmentServiceDeleteSegment:
dataset = SegmentTestDataFactory.create_dataset_mock()
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
):
mock_redis_get.return_value = None
@ -515,7 +523,7 @@ class TestSegmentServiceDeleteSegment:
document = SegmentTestDataFactory.create_document_mock()
dataset = SegmentTestDataFactory.create_dataset_mock()
with patch("services.dataset_service.redis_client.get") as mock_redis_get:
with patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get:
mock_redis_get.return_value = "1" # Deletion in progress
# Act & Assert
@ -529,7 +537,7 @@ class TestSegmentServiceDeleteSegments:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -562,8 +570,8 @@ class TestSegmentServiceDeleteSegments:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.delete_segment_from_index_task") as mock_task,
patch("services.dataset_service.select") as mock_select_func,
patch("services.dataset_service.delete_segment_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.select", autospec=True) as mock_select_func,
):
mock_select_func.return_value = mock_select
@ -594,7 +602,7 @@ class TestSegmentServiceUpdateSegmentsStatus:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -623,9 +631,9 @@ class TestSegmentServiceUpdateSegmentsStatus:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.enable_segments_to_index_task") as mock_task,
patch("services.dataset_service.select") as mock_select_func,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.enable_segments_to_index_task", autospec=True) as mock_task,
patch("services.dataset_service.select", autospec=True) as mock_select_func,
):
mock_redis_get.return_value = None
mock_select_func.return_value = mock_select
@ -657,10 +665,10 @@ class TestSegmentServiceUpdateSegmentsStatus:
mock_db_session.scalars.return_value = mock_scalars
with (
patch("services.dataset_service.redis_client.get") as mock_redis_get,
patch("services.dataset_service.disable_segments_from_index_task") as mock_task,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch("services.dataset_service.select") as mock_select_func,
patch("services.dataset_service.redis_client.get", autospec=True) as mock_redis_get,
patch("services.dataset_service.disable_segments_from_index_task", autospec=True) as mock_task,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
patch("services.dataset_service.select", autospec=True) as mock_select_func,
):
mock_redis_get.return_value = None
mock_now.return_value = "2024-01-01T00:00:00"
@ -693,7 +701,7 @@ class TestSegmentServiceGetSegments:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -771,7 +779,7 @@ class TestSegmentServiceGetSegmentById:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_get_segment_by_id_success(self, mock_db_session):
@ -814,7 +822,7 @@ class TestSegmentServiceGetChildChunks:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -876,7 +884,7 @@ class TestSegmentServiceGetChildChunkById:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_get_child_chunk_by_id_success(self, mock_db_session):
@ -919,7 +927,7 @@ class TestSegmentServiceCreateChildChunk:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -942,9 +950,11 @@ class TestSegmentServiceCreateChildChunk:
mock_db_session.query.return_value = mock_query
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -972,9 +982,11 @@ class TestSegmentServiceCreateChildChunk:
mock_db_session.query.return_value = mock_query
with (
patch("services.dataset_service.redis_client.lock") as mock_lock,
patch("services.dataset_service.VectorService.create_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash") as mock_hash,
patch("services.dataset_service.redis_client.lock", autospec=True) as mock_lock,
patch(
"services.dataset_service.VectorService.create_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash,
):
mock_lock.return_value.__enter__ = Mock()
mock_lock.return_value.__exit__ = Mock(return_value=None)
@ -994,7 +1006,7 @@ class TestSegmentServiceUpdateChildChunk:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
@pytest.fixture
@ -1014,8 +1026,10 @@ class TestSegmentServiceUpdateChildChunk:
dataset = SegmentTestDataFactory.create_dataset_mock()
with (
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch(
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_now.return_value = "2024-01-01T00:00:00"
@ -1040,8 +1054,10 @@ class TestSegmentServiceUpdateChildChunk:
dataset = SegmentTestDataFactory.create_dataset_mock()
with (
patch("services.dataset_service.VectorService.update_child_chunk_vector") as mock_vector_service,
patch("services.dataset_service.naive_utc_now") as mock_now,
patch(
"services.dataset_service.VectorService.update_child_chunk_vector", autospec=True
) as mock_vector_service,
patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now,
):
mock_vector_service.side_effect = Exception("Vector indexing failed")
mock_now.return_value = "2024-01-01T00:00:00"
@ -1059,7 +1075,7 @@ class TestSegmentServiceDeleteChildChunk:
@pytest.fixture
def mock_db_session(self):
"""Mock database session."""
with patch("services.dataset_service.db.session") as mock_db:
with patch("services.dataset_service.db.session", autospec=True) as mock_db:
yield mock_db
def test_delete_child_chunk_success(self, mock_db_session):
@ -1068,7 +1084,9 @@ class TestSegmentServiceDeleteChildChunk:
chunk = SegmentTestDataFactory.create_child_chunk_mock()
dataset = SegmentTestDataFactory.create_dataset_mock()
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
with patch(
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
) as mock_vector_service:
# Act
SegmentService.delete_child_chunk(chunk, dataset)
@ -1083,7 +1101,9 @@ class TestSegmentServiceDeleteChildChunk:
chunk = SegmentTestDataFactory.create_child_chunk_mock()
dataset = SegmentTestDataFactory.create_dataset_mock()
with patch("services.dataset_service.VectorService.delete_child_chunk_vector") as mock_vector_service:
with patch(
"services.dataset_service.VectorService.delete_child_chunk_vector", autospec=True
) as mock_vector_service:
mock_vector_service.side_effect = Exception("Vector deletion failed")
# Act & Assert

View File

@ -15,8 +15,8 @@ from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
class TestWorkflowRunArchiver:
"""Tests for the WorkflowRunArchiver class."""
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config")
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage")
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.dify_config", autospec=True)
@patch("services.retention.workflow_run.archive_paid_plan_workflow_run.get_archive_storage", autospec=True)
def test_archiver_initialization(self, mock_get_storage, mock_config):
"""Test archiver can be initialized with various options."""
from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver

View File

@ -214,7 +214,7 @@ def factory():
class TestAudioServiceASR:
"""Test speech-to-text (ASR) operations."""
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory):
"""Test successful ASR transcription in CHAT mode."""
# Arrange
@ -226,9 +226,7 @@ class TestAudioServiceASR:
file = factory.create_file_storage_mock()
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_speech2text.return_value = "Transcribed text"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -242,7 +240,7 @@ class TestAudioServiceASR:
call_args = mock_model_instance.invoke_speech2text.call_args
assert call_args.kwargs["user"] == "user-123"
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory):
"""Test successful ASR transcription in ADVANCED_CHAT mode."""
# Arrange
@ -254,9 +252,7 @@ class TestAudioServiceASR:
file = factory.create_file_storage_mock()
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_speech2text.return_value = "Workflow transcribed text"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -351,7 +347,7 @@ class TestAudioServiceASR:
with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"):
AudioService.transcript_asr(app_model=app, file=file)
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
"""Test that ASR raises error when no model instance is available."""
# Arrange
@ -363,8 +359,7 @@ class TestAudioServiceASR:
file = factory.create_file_storage_mock()
# Mock ModelManager to return None
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_manager.get_default_model_instance.return_value = None
# Act & Assert
@ -375,7 +370,7 @@ class TestAudioServiceASR:
class TestAudioServiceTTS:
"""Test text-to-speech (TTS) operations."""
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory):
"""Test successful TTS with text input."""
# Arrange
@ -388,9 +383,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -412,8 +405,8 @@ class TestAudioServiceTTS:
voice="en-US-Neural",
)
@patch("services.audio_service.db.session")
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.db.session", autospec=True)
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory):
"""Test successful TTS with message ID."""
# Arrange
@ -437,9 +430,7 @@ class TestAudioServiceTTS:
mock_query.first.return_value = message
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio from message"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -454,7 +445,7 @@ class TestAudioServiceTTS:
assert result == b"audio from message"
mock_model_instance.invoke_tts.assert_called_once()
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory):
"""Test TTS uses default voice when none specified."""
# Arrange
@ -467,9 +458,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"audio data"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -486,7 +475,7 @@ class TestAudioServiceTTS:
call_args = mock_model_instance.invoke_tts.call_args
assert call_args.kwargs["voice"] == "default-voice"
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory):
"""Test TTS gets first available voice when none is configured."""
# Arrange
@ -499,9 +488,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = [{"value": "auto-voice"}]
mock_model_instance.invoke_tts.return_value = b"audio data"
@ -518,8 +505,8 @@ class TestAudioServiceTTS:
call_args = mock_model_instance.invoke_tts.call_args
assert call_args.kwargs["voice"] == "auto-voice"
@patch("services.audio_service.WorkflowService")
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.WorkflowService", autospec=True)
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_workflow_mode_with_draft(
self, mock_model_manager_class, mock_workflow_service_class, factory
):
@ -533,14 +520,11 @@ class TestAudioServiceTTS:
)
# Mock WorkflowService
mock_workflow_service = MagicMock()
mock_workflow_service_class.return_value = mock_workflow_service
mock_workflow_service = mock_workflow_service_class.return_value
mock_workflow_service.get_draft_workflow.return_value = draft_workflow
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.invoke_tts.return_value = b"draft audio"
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -565,7 +549,7 @@ class TestAudioServiceTTS:
with pytest.raises(ValueError, match="Text is required"):
AudioService.transcript_tts(app_model=app, text=None)
@patch("services.audio_service.db.session")
@patch("services.audio_service.db.session", autospec=True)
def test_transcript_tts_returns_none_for_invalid_message_id(self, mock_db_session, factory):
"""Test that TTS returns None for invalid message ID format."""
# Arrange
@ -580,7 +564,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.db.session")
@patch("services.audio_service.db.session", autospec=True)
def test_transcript_tts_returns_none_for_nonexistent_message(self, mock_db_session, factory):
"""Test that TTS returns None when message doesn't exist."""
# Arrange
@ -601,7 +585,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.db.session")
@patch("services.audio_service.db.session", autospec=True)
def test_transcript_tts_returns_none_for_empty_message_answer(self, mock_db_session, factory):
"""Test that TTS returns None when message answer is empty."""
# Arrange
@ -627,7 +611,7 @@ class TestAudioServiceTTS:
# Assert
assert result is None
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory):
"""Test that TTS raises error when no voices are available."""
# Arrange
@ -640,9 +624,7 @@ class TestAudioServiceTTS:
)
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = [] # No voices available
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -655,7 +637,7 @@ class TestAudioServiceTTS:
class TestAudioServiceTTSVoices:
"""Test TTS voice listing operations."""
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_voices_success(self, mock_model_manager_class, factory):
"""Test successful retrieval of TTS voices."""
# Arrange
@ -668,9 +650,7 @@ class TestAudioServiceTTSVoices:
]
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.return_value = expected_voices
mock_model_manager.get_default_model_instance.return_value = mock_model_instance
@ -682,7 +662,7 @@ class TestAudioServiceTTSVoices:
assert result == expected_voices
mock_model_instance.get_tts_voices.assert_called_once_with(language)
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory):
"""Test that TTS voices raises error when no model instance is available."""
# Arrange
@ -690,15 +670,14 @@ class TestAudioServiceTTSVoices:
language = "en-US"
# Mock ModelManager to return None
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_manager.get_default_model_instance.return_value = None
# Act & Assert
with pytest.raises(ProviderNotSupportTextToSpeechServiceError):
AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language)
@patch("services.audio_service.ModelManager")
@patch("services.audio_service.ModelManager", autospec=True)
def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory):
"""Test that TTS voices propagates exceptions from model instance."""
# Arrange
@ -706,9 +685,7 @@ class TestAudioServiceTTSVoices:
language = "en-US"
# Mock ModelManager
mock_model_manager = MagicMock()
mock_model_manager_class.return_value = mock_model_manager
mock_model_manager = mock_model_manager_class.return_value
mock_model_instance = MagicMock()
mock_model_instance.get_tts_voices.side_effect = RuntimeError("Model error")
mock_model_manager.get_default_model_instance.return_value = mock_model_instance

View File

@ -237,9 +237,9 @@ class TestConversationServiceSummarization:
titles based on the first message.
"""
@patch("services.conversation_service.db.session")
@patch("services.conversation_service.ConversationService.get_conversation")
@patch("services.conversation_service.ConversationService.auto_generate_name")
@patch("services.conversation_service.db.session", autospec=True)
@patch("services.conversation_service.ConversationService.get_conversation", autospec=True)
@patch("services.conversation_service.ConversationService.auto_generate_name", autospec=True)
def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session):
"""
Test renaming conversation with auto-generation enabled.

View File

@ -28,10 +28,14 @@ class TestArchivedWorkflowRunDeletion:
with (
patch("services.retention.workflow_run.delete_archived_workflow_run.db", mock_db),
patch(
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker", return_value=session_maker
"services.retention.workflow_run.delete_archived_workflow_run.sessionmaker",
return_value=session_maker,
autospec=True,
),
patch.object(deleter, "_get_workflow_run_repo", return_value=repo),
patch.object(deleter, "_delete_run", return_value=MagicMock(success=True)) as mock_delete_run,
patch.object(deleter, "_get_workflow_run_repo", return_value=repo, autospec=True),
patch.object(
deleter, "_delete_run", return_value=MagicMock(success=True), autospec=True
) as mock_delete_run,
):
result = deleter.delete_by_run_id("run-1")
@ -46,7 +50,7 @@ class TestArchivedWorkflowRunDeletion:
run.id = "run-1"
run.tenant_id = "tenant-1"
with patch.object(deleter, "_get_workflow_run_repo") as mock_get_repo:
with patch.object(deleter, "_get_workflow_run_repo", autospec=True) as mock_get_repo:
result = deleter._delete_run(run)
assert result.success is True

View File

@ -402,7 +402,7 @@ class TestBillingDisabledPolicyFilterMessageIds:
class TestCreateMessageCleanPolicy:
"""Unit tests for create_message_clean_policy factory function."""
@patch("services.retention.conversation.messages_clean_policy.dify_config")
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
def test_billing_disabled_returns_billing_disabled_policy(self, mock_config):
"""Test that BILLING_ENABLED=False returns BillingDisabledPolicy."""
# Arrange
@ -414,8 +414,8 @@ class TestCreateMessageCleanPolicy:
# Assert
assert isinstance(policy, BillingDisabledPolicy)
@patch("services.retention.conversation.messages_clean_policy.BillingService")
@patch("services.retention.conversation.messages_clean_policy.dify_config")
@patch("services.retention.conversation.messages_clean_policy.BillingService", autospec=True)
@patch("services.retention.conversation.messages_clean_policy.dify_config", autospec=True)
def test_billing_enabled_policy_has_correct_internals(self, mock_config, mock_billing_service):
"""Test that BillingSandboxPolicy is created with correct internal values."""
# Arrange
@ -554,7 +554,7 @@ class TestMessagesCleanServiceFromDays:
MessagesCleanService.from_days(policy=policy, days=-1)
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
@ -586,7 +586,7 @@ class TestMessagesCleanServiceFromDays:
dry_run = True
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
@ -613,7 +613,7 @@ class TestMessagesCleanServiceFromDays:
policy = BillingDisabledPolicy()
# Act
with patch("services.retention.conversation.messages_clean_service.datetime") as mock_datetime:
with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime:
fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0)
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta

View File

@ -134,8 +134,8 @@ def factory():
class TestRecommendedAppServiceGetApps:
"""Test get_recommended_apps_and_categories operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_success_with_apps(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of recommended apps when apps are returned."""
# Arrange
@ -161,8 +161,8 @@ class TestRecommendedAppServiceGetApps:
mock_factory_class.get_recommend_app_factory.assert_called_once_with("remote")
mock_retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_fallback_to_builtin_when_empty(self, mock_config, mock_factory_class, factory):
"""Test fallback to builtin when no recommended apps are returned."""
# Arrange
@ -199,8 +199,8 @@ class TestRecommendedAppServiceGetApps:
# Verify fallback was called with en-US (hardcoded)
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_fallback_when_none_recommended_apps(self, mock_config, mock_factory_class, factory):
"""Test fallback when recommended_apps key is None."""
# Arrange
@ -232,8 +232,8 @@ class TestRecommendedAppServiceGetApps:
assert result == builtin_response
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_with_different_languages(self, mock_config, mock_factory_class, factory):
"""Test retrieval with different language codes."""
# Arrange
@ -262,8 +262,8 @@ class TestRecommendedAppServiceGetApps:
assert result["recommended_apps"][0]["id"] == f"app-{language}"
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommended_apps_uses_correct_factory_mode(self, mock_config, mock_factory_class, factory):
"""Test that correct factory is selected based on mode."""
# Arrange
@ -292,8 +292,8 @@ class TestRecommendedAppServiceGetApps:
class TestRecommendedAppServiceGetDetail:
"""Test get_recommend_app_detail operations."""
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_success(self, mock_config, mock_factory_class, factory):
"""Test successful retrieval of app detail."""
# Arrange
@ -324,8 +324,8 @@ class TestRecommendedAppServiceGetDetail:
assert result["name"] == "Productivity App"
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_with_different_modes(self, mock_config, mock_factory_class, factory):
"""Test app detail retrieval with different factory modes."""
# Arrange
@ -352,8 +352,8 @@ class TestRecommendedAppServiceGetDetail:
assert result["name"] == f"App from {mode}"
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_returns_none_when_not_found(self, mock_config, mock_factory_class, factory):
"""Test that None is returned when app is not found."""
# Arrange
@ -375,8 +375,8 @@ class TestRecommendedAppServiceGetDetail:
assert result is None
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_returns_empty_dict(self, mock_config, mock_factory_class, factory):
"""Test handling of empty dict response."""
# Arrange
@ -397,8 +397,8 @@ class TestRecommendedAppServiceGetDetail:
# Assert
assert result == {}
@patch("services.recommended_app_service.RecommendAppRetrievalFactory")
@patch("services.recommended_app_service.dify_config")
@patch("services.recommended_app_service.RecommendAppRetrievalFactory", autospec=True)
@patch("services.recommended_app_service.dify_config", autospec=True)
def test_get_recommend_app_detail_with_complex_model_config(self, mock_config, mock_factory_class, factory):
"""Test app detail with complex model configuration."""
# Arrange

View File

@ -201,8 +201,8 @@ def factory():
class TestSavedMessageServicePagination:
"""Test saved message pagination operations."""
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_account_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an Account user."""
# Arrange
@ -247,8 +247,8 @@ class TestSavedMessageServicePagination:
include_ids=["msg-0", "msg-1", "msg-2"],
)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_end_user(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with an EndUser."""
# Arrange
@ -301,8 +301,8 @@ class TestSavedMessageServicePagination:
with pytest.raises(ValueError, match="User is required"):
SavedMessageService.pagination_by_last_id(app_model=app, user=None, last_id=None, limit=20)
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_last_id(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination with last_id parameter."""
# Arrange
@ -340,8 +340,8 @@ class TestSavedMessageServicePagination:
call_args = mock_message_pagination.call_args
assert call_args.kwargs["last_id"] == last_id
@patch("services.saved_message_service.MessageService.pagination_by_last_id")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.pagination_by_last_id", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_pagination_with_empty_saved_messages(self, mock_db_session, mock_message_pagination, factory):
"""Test pagination when user has no saved messages."""
# Arrange
@ -377,8 +377,8 @@ class TestSavedMessageServicePagination:
class TestSavedMessageServiceSave:
"""Test save message operations."""
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_message_for_account(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an Account user."""
# Arrange
@ -407,8 +407,8 @@ class TestSavedMessageServiceSave:
assert saved_message.created_by_role == "account"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_message_for_end_user(self, mock_db_session, mock_get_message, factory):
"""Test saving a message for an EndUser."""
# Arrange
@ -437,7 +437,7 @@ class TestSavedMessageServiceSave:
assert saved_message.created_by_role == "end_user"
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_without_user_does_nothing(self, mock_db_session, factory):
"""Test that saving without user is a no-op."""
# Arrange
@ -451,8 +451,8 @@ class TestSavedMessageServiceSave:
mock_db_session.add.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_duplicate_message_is_idempotent(self, mock_db_session, mock_get_message, factory):
"""Test that saving an already saved message is idempotent."""
# Arrange
@ -480,8 +480,8 @@ class TestSavedMessageServiceSave:
mock_db_session.commit.assert_not_called()
mock_get_message.assert_not_called()
@patch("services.saved_message_service.MessageService.get_message")
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.MessageService.get_message", autospec=True)
@patch("services.saved_message_service.db.session", autospec=True)
def test_save_validates_message_exists(self, mock_db_session, mock_get_message, factory):
"""Test that save validates message exists through MessageService."""
# Arrange
@ -508,7 +508,7 @@ class TestSavedMessageServiceSave:
class TestSavedMessageServiceDelete:
"""Test delete saved message operations."""
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_saved_message_for_account(self, mock_db_session, factory):
"""Test deleting a saved message for an Account user."""
# Arrange
@ -535,7 +535,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_saved_message_for_end_user(self, mock_db_session, factory):
"""Test deleting a saved message for an EndUser."""
# Arrange
@ -562,7 +562,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_called_once_with(saved_message)
mock_db_session.commit.assert_called_once()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_without_user_does_nothing(self, mock_db_session, factory):
"""Test that deleting without user is a no-op."""
# Arrange
@ -576,7 +576,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_non_existent_saved_message_does_nothing(self, mock_db_session, factory):
"""Test that deleting a non-existent saved message is a no-op."""
# Arrange
@ -597,7 +597,7 @@ class TestSavedMessageServiceDelete:
mock_db_session.delete.assert_not_called()
mock_db_session.commit.assert_not_called()
@patch("services.saved_message_service.db.session")
@patch("services.saved_message_service.db.session", autospec=True)
def test_delete_only_affects_user_own_saved_messages(self, mock_db_session, factory):
"""Test that delete only removes the user's own saved message."""
# Arrange

View File

@ -315,7 +315,7 @@ class TestTagServiceRetrieval:
- get_tags_by_target_id: Get all tags bound to a specific target
"""
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tags_with_binding_counts(self, mock_db_session, factory):
"""
Test retrieving tags with their binding counts.
@ -372,7 +372,7 @@ class TestTagServiceRetrieval:
# Verify database query was called
mock_db_session.query.assert_called_once()
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tags_with_keyword_filter(self, mock_db_session, factory):
"""
Test retrieving tags filtered by keyword (case-insensitive).
@ -426,7 +426,7 @@ class TestTagServiceRetrieval:
# 2. Additional WHERE clause for keyword filtering
assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_target_ids_by_tag_ids(self, mock_db_session, factory):
"""
Test retrieving target IDs by tag IDs.
@ -482,7 +482,7 @@ class TestTagServiceRetrieval:
# Verify both queries were executed
assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory):
"""
Test that empty tag_ids returns empty list.
@ -510,7 +510,7 @@ class TestTagServiceRetrieval:
assert results == [], "Should return empty list for empty input"
mock_db_session.scalars.assert_not_called(), "Should not query database for empty input"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tag_by_tag_name(self, mock_db_session, factory):
"""
Test retrieving tags by name.
@ -552,7 +552,7 @@ class TestTagServiceRetrieval:
assert len(results) == 1, "Should find exactly one tag"
assert results[0].name == tag_name, "Tag name should match"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory):
"""
Test that missing tag_type or tag_name returns empty list.
@ -580,7 +580,7 @@ class TestTagServiceRetrieval:
# Verify no database queries were executed
mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tags_by_target_id(self, mock_db_session, factory):
"""
Test retrieving tags associated with a specific target.
@ -651,10 +651,10 @@ class TestTagServiceCRUD:
- get_tag_binding_count: Get count of bindings for a tag
"""
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.db.session")
@patch("services.tag_service.uuid.uuid4")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
@patch("services.tag_service.uuid.uuid4", autospec=True)
def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
"""
Test creating a new tag.
@ -709,8 +709,8 @@ class TestTagServiceCRUD:
assert added_tag.created_by == "user-123", "Created by should match current user"
assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant"
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory):
"""
Test that creating a tag with duplicate name raises ValueError.
@ -740,9 +740,9 @@ class TestTagServiceCRUD:
with pytest.raises(ValueError, match="Tag name already exists"):
TagService.save_tags(args)
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory):
"""
Test updating a tag name.
@ -792,9 +792,9 @@ class TestTagServiceCRUD:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.get_tag_by_tag_name")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_update_tags_raises_error_for_duplicate_name(
self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory
):
@ -826,7 +826,7 @@ class TestTagServiceCRUD:
with pytest.raises(ValueError, match="Tag name already exists"):
TagService.update_tags(args, tag_id="tag-123")
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory):
"""
Test that updating a non-existent tag raises NotFound.
@ -848,8 +848,8 @@ class TestTagServiceCRUD:
mock_query.first.return_value = None
# Mock duplicate check and current_user
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[]):
with patch("services.tag_service.current_user") as mock_user:
with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True):
with patch("services.tag_service.current_user", autospec=True) as mock_user:
mock_user.current_tenant_id = "tenant-123"
args = {"name": "New Name", "type": "app"}
@ -858,7 +858,7 @@ class TestTagServiceCRUD:
with pytest.raises(NotFound, match="Tag not found"):
TagService.update_tags(args, tag_id="nonexistent")
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_get_tag_binding_count(self, mock_db_session, factory):
"""
Test getting the count of bindings for a tag.
@ -894,7 +894,7 @@ class TestTagServiceCRUD:
# Verify count matches expectation
assert result == expected_count, "Binding count should match"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag(self, mock_db_session, factory):
"""
Test deleting a tag and its bindings.
@ -950,7 +950,7 @@ class TestTagServiceCRUD:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.db.session")
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag_raises_not_found(self, mock_db_session, factory):
"""
Test that deleting a non-existent tag raises NotFound.
@ -996,9 +996,9 @@ class TestTagServiceBindings:
- check_target_exists: Validate target (dataset/app) existence
"""
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory):
"""
Test creating tag bindings.
@ -1047,9 +1047,9 @@ class TestTagServiceBindings:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.current_user")
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory):
"""
Test that saving duplicate bindings is idempotent.
@ -1088,8 +1088,8 @@ class TestTagServiceBindings:
# Verify no new binding was added (idempotent)
mock_db_session.add.assert_not_called(), "Should not create duplicate binding"
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory):
"""
Test deleting a tag binding.
@ -1136,8 +1136,8 @@ class TestTagServiceBindings:
# Verify transaction was committed
mock_db_session.commit.assert_called_once(), "Should commit transaction"
@patch("services.tag_service.TagService.check_target_exists")
@patch("services.tag_service.db.session")
@patch("services.tag_service.TagService.check_target_exists", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory):
"""
Test that deleting a non-existent binding is a no-op.
@ -1173,8 +1173,8 @@ class TestTagServiceBindings:
# Verify no commit was made (nothing changed)
mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete"
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory):
"""
Test validating that a dataset target exists.
@ -1214,8 +1214,8 @@ class TestTagServiceBindings:
# Verify no exception was raised and query was executed
mock_db_session.query.assert_called_once(), "Should query database for dataset"
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory):
"""
Test validating that an app target exists.
@ -1255,8 +1255,8 @@ class TestTagServiceBindings:
# Verify no exception was raised and query was executed
mock_db_session.query.assert_called_once(), "Should query database for app"
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_raises_not_found_for_missing_dataset(
self, mock_db_session, mock_current_user, factory
):
@ -1287,8 +1287,8 @@ class TestTagServiceBindings:
with pytest.raises(NotFound, match="Dataset not found"):
TagService.check_target_exists("knowledge", "nonexistent")
@patch("services.tag_service.current_user")
@patch("services.tag_service.db.session")
@patch("services.tag_service.current_user", autospec=True)
@patch("services.tag_service.db.session", autospec=True)
def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory):
"""
Test that missing app raises NotFound.

View File

@ -87,7 +87,7 @@ class TestWebhookServiceUnit:
webhook_trigger = MagicMock()
webhook_trigger.tenant_id = "test_tenant"
with patch.object(WebhookService, "_process_file_uploads") as mock_process_files:
with patch.object(WebhookService, "_process_file_uploads", autospec=True) as mock_process_files:
mock_process_files.return_value = {"file": "mocked_file_obj"}
webhook_data = WebhookService.extract_webhook_data(webhook_trigger)
@ -123,8 +123,10 @@ class TestWebhookServiceUnit:
mock_file.to_dict.return_value = {"file": "data"}
with (
patch.object(WebhookService, "_detect_binary_mimetype", return_value="text/plain") as mock_detect,
patch.object(WebhookService, "_create_file_from_binary") as mock_create,
patch.object(
WebhookService, "_detect_binary_mimetype", return_value="text/plain", autospec=True
) as mock_detect,
patch.object(WebhookService, "_create_file_from_binary", autospec=True) as mock_create,
):
mock_create.return_value = mock_file
body, files = WebhookService._extract_octet_stream_body(webhook_trigger)
@ -168,7 +170,7 @@ class TestWebhookServiceUnit:
fake_magic.from_buffer.side_effect = real_magic.MagicException("magic error")
monkeypatch.setattr("services.trigger.webhook_service.magic", fake_magic)
with patch("services.trigger.webhook_service.logger") as mock_logger:
with patch("services.trigger.webhook_service.logger", autospec=True) as mock_logger:
result = WebhookService._detect_binary_mimetype(b"binary data")
assert result == "application/octet-stream"
@ -245,15 +247,12 @@ class TestWebhookServiceUnit:
assert response_data[0]["id"] == 1
assert response_data[1]["id"] == 2
@patch("services.trigger.webhook_service.ToolFileManager")
@patch("services.trigger.webhook_service.file_factory")
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
@patch("services.trigger.webhook_service.file_factory", autospec=True)
def test_process_file_uploads_success(self, mock_file_factory, mock_tool_file_manager):
"""Test successful file upload processing."""
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
@ -285,15 +284,12 @@ class TestWebhookServiceUnit:
assert mock_tool_file_manager.call_count == 2
assert mock_file_factory.build_from_mapping.call_count == 2
@patch("services.trigger.webhook_service.ToolFileManager")
@patch("services.trigger.webhook_service.file_factory")
@patch("services.trigger.webhook_service.ToolFileManager", autospec=True)
@patch("services.trigger.webhook_service.file_factory", autospec=True)
def test_process_file_uploads_with_errors(self, mock_file_factory, mock_tool_file_manager):
"""Test file upload processing with errors."""
# Mock ToolFileManager
mock_tool_file_instance = MagicMock()
mock_tool_file_manager.return_value = mock_tool_file_instance
# Mock file creation
mock_tool_file_instance = mock_tool_file_manager.return_value # Mock file creation
mock_tool_file = MagicMock()
mock_tool_file.id = "test_file_id"
mock_tool_file_instance.create_file_by_raw.return_value = mock_tool_file
@ -544,8 +540,8 @@ class TestWebhookServiceUnit:
# Mock the WebhookService methods
with (
patch.object(WebhookService, "get_webhook_trigger_and_workflow") as mock_get_trigger,
patch.object(WebhookService, "extract_and_validate_webhook_data") as mock_extract,
patch.object(WebhookService, "get_webhook_trigger_and_workflow", autospec=True) as mock_get_trigger,
patch.object(WebhookService, "extract_and_validate_webhook_data", autospec=True) as mock_extract,
):
mock_trigger = MagicMock()
mock_workflow = MagicMock()

View File

@ -124,7 +124,7 @@ class TestWorkflowRunService:
"""Create WorkflowRunService instance with mocked dependencies."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
return service
@ -135,7 +135,7 @@ class TestWorkflowRunService:
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(mock_engine)
return service
@ -146,7 +146,7 @@ class TestWorkflowRunService:
"""Test WorkflowRunService initialization with session_factory."""
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
service = WorkflowRunService(session_factory)
@ -158,9 +158,11 @@ class TestWorkflowRunService:
mock_engine = create_autospec(Engine)
session_factory, _ = mock_session_factory
with patch("services.workflow_run_service.DifyAPIRepositoryFactory") as mock_factory:
with patch("services.workflow_run_service.DifyAPIRepositoryFactory", autospec=True) as mock_factory:
mock_factory.create_api_workflow_run_repository.return_value = mock_workflow_run_repository
with patch("services.workflow_run_service.sessionmaker", return_value=session_factory) as mock_sessionmaker:
with patch(
"services.workflow_run_service.sessionmaker", return_value=session_factory, autospec=True
) as mock_sessionmaker:
service = WorkflowRunService(mock_engine)
mock_sessionmaker.assert_called_once_with(bind=mock_engine, expire_on_commit=False)

View File

@ -141,7 +141,7 @@ class TestDraftVariableSaver:
def test_draft_saver_with_small_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
) as _mock_try_offload:
_mock_try_offload.return_value = None
mock_segment = StringSegment(value="small value")
@ -153,7 +153,7 @@ class TestDraftVariableSaver:
def test_draft_saver_with_large_variables(self, draft_saver, mock_session):
with patch(
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable"
"services.workflow_draft_variable_service.DraftVariableSaver._try_offload_large_variable", autospec=True
) as _mock_try_offload:
mock_segment = StringSegment(value="small value")
mock_draft_var_file = WorkflowDraftVariableFile(
@ -170,7 +170,7 @@ class TestDraftVariableSaver:
# Should not have large variable metadata
assert draft_var.file_id == mock_draft_var_file.id
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable")
@patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True)
def test_save_method_integration(self, mock_batch_upsert, draft_saver):
"""Test complete save workflow."""
outputs = {"result": {"data": "test_output"}, "metadata": {"type": "llm_response"}}
@ -222,7 +222,7 @@ class TestWorkflowDraftVariableService:
name="test_var",
value=StringSegment(value="reset_value"),
)
with patch.object(service, "_reset_conv_var", return_value=expected_result) as mock_reset_conv:
with patch.object(service, "_reset_conv_var", return_value=expected_result, autospec=True) as mock_reset_conv:
result = service.reset_variable(workflow, variable)
mock_reset_conv.assert_called_once_with(workflow, variable)
@ -330,8 +330,8 @@ class TestWorkflowDraftVariableService:
# Mock workflow methods
mock_node_config = {"type": "test_node"}
with (
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM),
patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True),
patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True),
):
result = service._reset_node_var_or_sys_var(workflow, variable)

View File

@ -50,7 +50,7 @@ def pipeline_id():
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
with patch("tasks.clean_dataset_task.session_factory", autospec=True) as mock_sf:
mock_session = MagicMock()
# context manager for create_session()
cm = MagicMock()
@ -79,7 +79,7 @@ def mock_db_session():
@pytest.fixture
def mock_storage():
"""Mock storage client."""
with patch("tasks.clean_dataset_task.storage") as mock_storage:
with patch("tasks.clean_dataset_task.storage", autospec=True) as mock_storage:
mock_storage.delete.return_value = None
yield mock_storage
@ -87,7 +87,7 @@ def mock_storage():
@pytest.fixture
def mock_index_processor_factory():
"""Mock IndexProcessorFactory."""
with patch("tasks.clean_dataset_task.IndexProcessorFactory") as mock_factory:
with patch("tasks.clean_dataset_task.IndexProcessorFactory", autospec=True) as mock_factory:
mock_processor = MagicMock()
mock_processor.clean.return_value = None
mock_factory_instance = MagicMock()
@ -104,7 +104,7 @@ def mock_index_processor_factory():
@pytest.fixture
def mock_get_image_upload_file_ids():
"""Mock get_image_upload_file_ids function."""
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_func:
with patch("tasks.clean_dataset_task.get_image_upload_file_ids", autospec=True) as mock_func:
mock_func.return_value = []
yield mock_func

View File

@ -75,7 +75,7 @@ def mock_document(document_id, dataset_id, notion_workspace_id, notion_page_id,
@pytest.fixture
def mock_db_session(mock_document, mock_dataset):
"""Mock session_factory.create_session to drive deterministic read-only task flow."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory:
with patch("tasks.document_indexing_sync_task.session_factory", autospec=True) as mock_session_factory:
session = MagicMock()
session.scalars.return_value.all.return_value = []
session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
@ -96,7 +96,7 @@ def mock_db_session(mock_document, mock_dataset):
@pytest.fixture
def mock_datasource_provider_service():
"""Mock datasource credential provider."""
with patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class:
with patch("tasks.document_indexing_sync_task.DatasourceProviderService", autospec=True) as mock_service_class:
mock_service = MagicMock()
mock_service.get_datasource_credentials.return_value = {"integration_secret": "test_token"}
mock_service_class.return_value = mock_service
@ -106,7 +106,7 @@ def mock_datasource_provider_service():
@pytest.fixture
def mock_notion_extractor():
"""Mock notion extractor class and instance."""
with patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class:
with patch("tasks.document_indexing_sync_task.NotionExtractor", autospec=True) as mock_extractor_class:
mock_extractor = MagicMock()
mock_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
mock_extractor_class.return_value = mock_extractor

View File

@ -95,7 +95,7 @@ def mock_document_segments(document_ids):
@pytest.fixture
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf:
session = MagicMock()
# Allow tests to observe session.close() via context manager teardown
session.close = MagicMock()
@ -118,7 +118,7 @@ def mock_db_session():
@pytest.fixture
def mock_indexing_runner():
"""Mock IndexingRunner."""
with patch("tasks.duplicate_document_indexing_task.IndexingRunner") as mock_runner_class:
with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class:
mock_runner = MagicMock(spec=IndexingRunner)
mock_runner_class.return_value = mock_runner
yield mock_runner
@ -127,7 +127,7 @@ def mock_indexing_runner():
@pytest.fixture
def mock_feature_service():
"""Mock FeatureService."""
with patch("tasks.duplicate_document_indexing_task.FeatureService") as mock_service:
with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service:
mock_features = Mock()
mock_features.billing = Mock()
mock_features.billing.enabled = False
@ -141,7 +141,7 @@ def mock_feature_service():
@pytest.fixture
def mock_index_processor_factory():
"""Mock IndexProcessorFactory."""
with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory") as mock_factory:
with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory:
mock_processor = MagicMock()
mock_processor.clean = Mock()
mock_factory.return_value.init_index_processor.return_value = mock_processor
@ -151,7 +151,7 @@ def mock_index_processor_factory():
@pytest.fixture
def mock_tenant_isolated_queue():
"""Mock TenantIsolatedTaskQueue."""
with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue") as mock_queue_class:
with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class:
mock_queue = MagicMock(spec=TenantIsolatedTaskQueue)
mock_queue.pull_tasks.return_value = []
mock_queue.delete_task_key = Mock()
@ -168,7 +168,7 @@ def mock_tenant_isolated_queue():
class TestDuplicateDocumentIndexingTask:
"""Tests for the deprecated duplicate_document_indexing_task function."""
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
def test_duplicate_document_indexing_task_calls_core_function(self, mock_core_func, dataset_id, document_ids):
"""Test that duplicate_document_indexing_task calls the core _duplicate_document_indexing_task function."""
# Act
@ -177,7 +177,7 @@ class TestDuplicateDocumentIndexingTask:
# Assert
mock_core_func.assert_called_once_with(dataset_id, document_ids)
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
def test_duplicate_document_indexing_task_with_empty_document_ids(self, mock_core_func, dataset_id):
"""Test duplicate_document_indexing_task with empty document_ids list."""
# Arrange
@ -445,7 +445,7 @@ class TestDuplicateDocumentIndexingTaskCore:
class TestDuplicateDocumentIndexingTaskWithTenantQueue:
"""Tests for _duplicate_document_indexing_task_with_tenant_queue function."""
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
def test_tenant_queue_wrapper_calls_core_function(
self,
mock_core_func,
@ -464,7 +464,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
# Assert
mock_core_func.assert_called_once_with(dataset_id, document_ids)
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
def test_tenant_queue_wrapper_deletes_key_when_no_tasks(
self,
mock_core_func,
@ -484,7 +484,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
# Assert
mock_tenant_isolated_queue.delete_task_key.assert_called_once()
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
def test_tenant_queue_wrapper_processes_next_tasks(
self,
mock_core_func,
@ -514,7 +514,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
document_ids=document_ids,
)
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task", autospec=True)
def test_tenant_queue_wrapper_handles_core_function_error(
self,
mock_core_func,
@ -544,7 +544,7 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue:
class TestNormalDuplicateDocumentIndexingTask:
"""Tests for normal_duplicate_document_indexing_task function."""
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
def test_normal_task_calls_tenant_queue_wrapper(
self,
mock_wrapper_func,
@ -561,7 +561,7 @@ class TestNormalDuplicateDocumentIndexingTask:
tenant_id, dataset_id, document_ids, normal_duplicate_document_indexing_task
)
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
def test_normal_task_with_empty_document_ids(
self,
mock_wrapper_func,
@ -589,7 +589,7 @@ class TestNormalDuplicateDocumentIndexingTask:
class TestPriorityDuplicateDocumentIndexingTask:
"""Tests for priority_duplicate_document_indexing_task function."""
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
def test_priority_task_calls_tenant_queue_wrapper(
self,
mock_wrapper_func,
@ -606,7 +606,7 @@ class TestPriorityDuplicateDocumentIndexingTask:
tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
)
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
def test_priority_task_with_single_document(
self,
mock_wrapper_func,
@ -625,7 +625,7 @@ class TestPriorityDuplicateDocumentIndexingTask:
tenant_id, dataset_id, document_ids, priority_duplicate_document_indexing_task
)
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue")
@patch("tasks.duplicate_document_indexing_task._duplicate_document_indexing_task_with_tenant_queue", autospec=True)
def test_priority_task_with_large_batch(
self,
mock_wrapper_func,

View File

@ -321,7 +321,9 @@ def test_structured_output_parser():
)
else:
# Test successful cases
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
with patch(
"core.llm_generator.output_parser.structured_output.json_repair.loads", autospec=True
) as mock_json_repair:
# Configure json_repair mock for cases that need it
if case["name"] == "json_repair_scenario":
mock_json_repair.return_value = {"name": "test"}
@ -402,7 +404,9 @@ def test_parse_structured_output_edge_cases():
prompt_messages = [UserPromptMessage(content="Test reasoning")]
with patch("core.llm_generator.output_parser.structured_output.json_repair.loads") as mock_json_repair:
with patch(
"core.llm_generator.output_parser.structured_output.json_repair.loads", autospec=True
) as mock_json_repair:
# Mock json_repair to return a list with dict
mock_json_repair.return_value = [{"thought": "reasoning process"}, "other content"]