mirror of https://github.com/langgenius/dify.git
chore: fix the test cases
This commit is contained in:
parent
4a9fe55976
commit
3c6035490d
|
|
@ -276,31 +276,34 @@ def sse_client(
|
|||
read_queue: ReadQueue | None = None
|
||||
write_queue: WriteQueue | None = None
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
executor = ThreadPoolExecutor()
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
|
||||
with ssrf_proxy_sse_connect(
|
||||
url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
|
||||
) as event_source:
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
read_queue, write_queue = transport.connect(executor, client, event_source)
|
||||
|
||||
yield read_queue, write_queue
|
||||
yield read_queue, write_queue
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
if exc.response.status_code == 401:
|
||||
raise MCPAuthError()
|
||||
raise MCPConnectionError()
|
||||
except Exception:
|
||||
logger.exception("Error connecting to SSE endpoint")
|
||||
raise
|
||||
finally:
|
||||
# Clean up queues
|
||||
if read_queue:
|
||||
read_queue.put(None)
|
||||
if write_queue:
|
||||
write_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage):
|
||||
|
|
|
|||
|
|
@ -434,45 +434,48 @@ def streamablehttp_client(
|
|||
server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
|
||||
client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
|
||||
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
executor = ThreadPoolExecutor(max_workers=2)
|
||||
try:
|
||||
with create_ssrf_proxy_mcp_http_client(
|
||||
headers=transport.request_headers,
|
||||
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
|
||||
) as client:
|
||||
# Define callbacks that need access to thread pool
|
||||
def start_get_stream():
|
||||
"""Start a worker thread to handle server-initiated messages."""
|
||||
executor.submit(transport.handle_get_stream, client, server_to_client_queue)
|
||||
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
# Start the post_writer worker thread
|
||||
executor.submit(
|
||||
transport.post_writer,
|
||||
client,
|
||||
client_to_server_queue, # Queue for messages FROM client TO server
|
||||
server_to_client_queue, # Queue for messages FROM server TO client
|
||||
start_get_stream,
|
||||
)
|
||||
|
||||
try:
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
yield (
|
||||
server_to_client_queue, # Queue for receiving messages FROM server
|
||||
client_to_server_queue, # Queue for sending messages TO server
|
||||
transport.get_session_id,
|
||||
)
|
||||
finally:
|
||||
if transport.session_id and terminate_on_close:
|
||||
transport.terminate_session(client)
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
# Signal threads to stop
|
||||
client_to_server_queue.put(None)
|
||||
finally:
|
||||
# Clear any remaining items and add None sentinel to unblock any waiting threads
|
||||
try:
|
||||
while not client_to_server_queue.empty():
|
||||
client_to_server_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
pass
|
||||
|
||||
client_to_server_queue.put(None)
|
||||
server_to_client_queue.put(None)
|
||||
|
||||
# Shutdown executor without waiting to prevent hanging
|
||||
executor.shutdown(wait=False)
|
||||
|
|
|
|||
|
|
@ -201,11 +201,14 @@ class BaseSession(
|
|||
self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
|
||||
except TimeoutError:
|
||||
# If the receiver loop is still running after timeout, we'll force shutdown
|
||||
pass
|
||||
# Cancel the future to interrupt the receiver loop
|
||||
self._receiver_future.cancel()
|
||||
|
||||
# Shutdown the executor
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=True)
|
||||
# Use non-blocking shutdown to prevent hanging
|
||||
# The receiver thread should have already exited due to the None message in the queue
|
||||
self._executor.shutdown(wait=False)
|
||||
|
||||
def send_request(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -113,9 +113,9 @@ class TestMCPToolManageService:
|
|||
mcp_provider = MCPToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
name=fake.company(),
|
||||
server_identifier=fake.uuid4(),
|
||||
server_identifier=str(fake.uuid4()),
|
||||
server_url="encrypted_server_url",
|
||||
server_url_hash=fake.sha256(),
|
||||
server_url_hash=str(fake.sha256()),
|
||||
user_id=user_id,
|
||||
authed=False,
|
||||
tools="[]",
|
||||
|
|
@ -364,6 +364,7 @@ class TestMCPToolManageService:
|
|||
)
|
||||
|
||||
# Act: Execute the method under test
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
|
|
@ -376,8 +377,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#FF6B6B",
|
||||
server_identifier="test_identifier_123",
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
@ -423,6 +426,7 @@ class TestMCPToolManageService:
|
|||
)
|
||||
|
||||
# Create first provider
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
|
|
@ -435,8 +439,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#FF6B6B",
|
||||
server_identifier="test_identifier_1",
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Act & Assert: Verify proper error handling for duplicate name
|
||||
|
|
@ -450,8 +456,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="test_identifier_2",
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
def test_create_mcp_provider_duplicate_server_url(
|
||||
|
|
@ -472,6 +480,7 @@ class TestMCPToolManageService:
|
|||
)
|
||||
|
||||
# Create first provider
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
|
|
@ -484,8 +493,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#FF6B6B",
|
||||
server_identifier="test_identifier_1",
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Act & Assert: Verify proper error handling for duplicate server URL
|
||||
|
|
@ -499,8 +510,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="test_identifier_2",
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
def test_create_mcp_provider_duplicate_server_identifier(
|
||||
|
|
@ -521,6 +534,7 @@ class TestMCPToolManageService:
|
|||
)
|
||||
|
||||
# Create first provider
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
|
|
@ -533,8 +547,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#FF6B6B",
|
||||
server_identifier="test_identifier_123",
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=30.0,
|
||||
sse_read_timeout=300.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Act & Assert: Verify proper error handling for duplicate server identifier
|
||||
|
|
@ -548,8 +564,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="test_identifier_123", # Duplicate identifier
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
|
|
@ -1057,6 +1075,8 @@ class TestMCPToolManageService:
|
|||
db.session.commit()
|
||||
|
||||
# Act: Execute the method under test
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
service.update_provider(
|
||||
tenant_id=tenant.id,
|
||||
|
|
@ -1067,8 +1087,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="updated_identifier_123",
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
@ -1082,7 +1104,7 @@ class TestMCPToolManageService:
|
|||
# Verify icon was updated
|
||||
import json
|
||||
|
||||
icon_data = json.loads(mcp_provider.icon)
|
||||
icon_data = json.loads(mcp_provider.icon or "{}")
|
||||
assert icon_data["content"] == "🚀"
|
||||
assert icon_data["background"] == "#4ECDC4"
|
||||
|
||||
|
|
@ -1122,6 +1144,7 @@ class TestMCPToolManageService:
|
|||
}
|
||||
|
||||
# Act: Execute the method under test
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
|
|
@ -1134,8 +1157,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="updated_identifier_123",
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
# Assert: Verify the expected outcomes
|
||||
|
|
@ -1183,6 +1208,7 @@ class TestMCPToolManageService:
|
|||
db.session.commit()
|
||||
|
||||
# Act & Assert: Verify proper error handling for duplicate name
|
||||
from core.entities.mcp_provider import MCPConfiguration
|
||||
from extensions.ext_database import db
|
||||
|
||||
service = MCPToolManageService(db.session())
|
||||
|
|
@ -1196,8 +1222,10 @@ class TestMCPToolManageService:
|
|||
icon_type="emoji",
|
||||
icon_background="#4ECDC4",
|
||||
server_identifier="unique_identifier",
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
configuration=MCPConfiguration(
|
||||
timeout=45.0,
|
||||
sse_read_timeout=400.0,
|
||||
),
|
||||
)
|
||||
|
||||
def test_update_mcp_provider_credentials_success(
|
||||
|
|
@ -1258,7 +1286,7 @@ class TestMCPToolManageService:
|
|||
# Verify credentials were encrypted and merged
|
||||
import json
|
||||
|
||||
credentials = json.loads(mcp_provider.encrypted_credentials)
|
||||
credentials = json.loads(mcp_provider.encrypted_credentials or "{}")
|
||||
assert "existing_key" in credentials
|
||||
assert "new_key" in credentials
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue