chore: fix the test cases

This commit is contained in:
Novice 2025-10-14 21:36:51 +08:00
parent 4a9fe55976
commit 3c6035490d
No known key found for this signature in database
GPG Key ID: EE3F68E3105DAAAB
4 changed files with 123 additions and 86 deletions

View File

@ -276,31 +276,34 @@ def sse_client(
read_queue: ReadQueue | None = None
write_queue: WriteQueue | None = None
with ThreadPoolExecutor() as executor:
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
executor = ThreadPoolExecutor()
try:
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
with ssrf_proxy_sse_connect(
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
) as event_source:
event_source.response.raise_for_status()
read_queue, write_queue = transport.connect(executor, client, event_source)
read_queue, write_queue = transport.connect(executor, client, event_source)
yield read_queue, write_queue
yield read_queue, write_queue
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
except httpx.HTTPStatusError as exc:
if exc.response.status_code == 401:
raise MCPAuthError()
raise MCPConnectionError()
except Exception:
logger.exception("Error connecting to SSE endpoint")
raise
finally:
# Clean up queues
if read_queue:
read_queue.put(None)
if write_queue:
write_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):

View File

@ -434,45 +434,48 @@ def streamablehttp_client(
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
with ThreadPoolExecutor(max_workers=2) as executor:
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
executor = ThreadPoolExecutor(max_workers=2)
try:
with create_ssrf_proxy_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
# Define callbacks that need access to thread pool
def start_get_stream():
"""Start a worker thread to handle server-initiated messages."""
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
# Start the post_writer worker thread
executor.submit(
transport.post_writer,
client,
client_to_server_queue, # Queue for messages FROM client TO server
server_to_client_queue, # Queue for messages FROM server TO client
start_get_stream,
)
try:
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
yield (
server_to_client_queue, # Queue for receiving messages FROM server
client_to_server_queue, # Queue for sending messages TO server
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
transport.terminate_session(client)
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Signal threads to stop
client_to_server_queue.put(None)
finally:
# Clear any remaining items and add None sentinel to unblock any waiting threads
try:
while not client_to_server_queue.empty():
client_to_server_queue.get_nowait()
except queue.Empty:
pass
client_to_server_queue.put(None)
server_to_client_queue.put(None)
# Shutdown executor without waiting to prevent hanging
executor.shutdown(wait=False)

View File

@ -201,11 +201,14 @@ class BaseSession(
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
except TimeoutError:
# If the receiver loop is still running after timeout, we'll force shutdown
pass
# Cancel the future to interrupt the receiver loop
self._receiver_future.cancel()
# Shutdown the executor
if self._executor:
self._executor.shutdown(wait=True)
# Use non-blocking shutdown to prevent hanging
# The receiver thread should have already exited due to the None message in the queue
self._executor.shutdown(wait=False)
def send_request(
self,

View File

@ -113,9 +113,9 @@ class TestMCPToolManageService:
mcp_provider = MCPToolProvider(
tenant_id=tenant_id,
name=fake.company(),
server_identifier=fake.uuid4(),
server_identifier=str(fake.uuid4()),
server_url="encrypted_server_url",
server_url_hash=fake.sha256(),
server_url_hash=str(fake.sha256()),
user_id=user_id,
authed=False,
tools="[]",
@ -364,6 +364,7 @@ class TestMCPToolManageService:
)
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
@ -376,8 +377,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Assert: Verify the expected outcomes
@ -423,6 +426,7 @@ class TestMCPToolManageService:
)
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
@ -435,8 +439,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate name
@ -450,8 +456,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_url(
@ -472,6 +480,7 @@ class TestMCPToolManageService:
)
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
@ -484,8 +493,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_1",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server URL
@ -499,8 +510,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_2",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_create_mcp_provider_duplicate_server_identifier(
@ -521,6 +534,7 @@ class TestMCPToolManageService:
)
# Create first provider
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
@ -533,8 +547,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#FF6B6B",
server_identifier="test_identifier_123",
timeout=30.0,
sse_read_timeout=300.0,
configuration=MCPConfiguration(
timeout=30.0,
sse_read_timeout=300.0,
),
)
# Act & Assert: Verify proper error handling for duplicate server identifier
@ -548,8 +564,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="test_identifier_123", # Duplicate identifier
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
@ -1057,6 +1075,8 @@ class TestMCPToolManageService:
db.session.commit()
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
service = MCPToolManageService(db.session())
service.update_provider(
tenant_id=tenant.id,
@ -1067,8 +1087,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@ -1082,7 +1104,7 @@ class TestMCPToolManageService:
# Verify icon was updated
import json
icon_data = json.loads(mcp_provider.icon)
icon_data = json.loads(mcp_provider.icon or "{}")
assert icon_data["content"] == "🚀"
assert icon_data["background"] == "#4ECDC4"
@ -1122,6 +1144,7 @@ class TestMCPToolManageService:
}
# Act: Execute the method under test
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
@ -1134,8 +1157,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="updated_identifier_123",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
# Assert: Verify the expected outcomes
@ -1183,6 +1208,7 @@ class TestMCPToolManageService:
db.session.commit()
# Act & Assert: Verify proper error handling for duplicate name
from core.entities.mcp_provider import MCPConfiguration
from extensions.ext_database import db
service = MCPToolManageService(db.session())
@ -1196,8 +1222,10 @@ class TestMCPToolManageService:
icon_type="emoji",
icon_background="#4ECDC4",
server_identifier="unique_identifier",
timeout=45.0,
sse_read_timeout=400.0,
configuration=MCPConfiguration(
timeout=45.0,
sse_read_timeout=400.0,
),
)
def test_update_mcp_provider_credentials_success(
@ -1258,7 +1286,7 @@ class TestMCPToolManageService:
# Verify credentials were encrypted and merged
import json
credentials = json.loads(mcp_provider.encrypted_credentials)
credentials = json.loads(mcp_provider.encrypted_credentials or "{}")
assert "existing_key" in credentials
assert "new_key" in credentials