more httpx (#25651)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Asuka Minato 2025-09-23 00:07:09 +09:00 committed by GitHub
parent 0c4193bd91
commit 8940decd1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 120 additions and 117 deletions

View File

@ -5,7 +5,7 @@ import logging
import os import os
import time import time
import requests import httpx
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,10 +30,10 @@ class NacosHttpClient:
params = {} params = {}
try: try:
self._inject_auth_info(headers, params) self._inject_auth_info(headers, params)
response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params) response = httpx.request(method, url="http://" + self.server + url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
return response.text return response.text
except requests.RequestException as e: except httpx.RequestError as e:
return f"Request to Nacos failed: {e}" return f"Request to Nacos failed: {e}"
def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None: def _inject_auth_info(self, headers: dict[str, str], params: dict[str, str], module: str = "config") -> None:
@ -78,7 +78,7 @@ class NacosHttpClient:
params = {"username": self.username, "password": self.password} params = {"username": self.username, "password": self.password}
url = "http://" + self.server + "/nacos/v1/auth/login" url = "http://" + self.server + "/nacos/v1/auth/login"
try: try:
resp = requests.request("POST", url, headers=None, params=params) resp = httpx.request("POST", url, headers=None, params=params)
resp.raise_for_status() resp.raise_for_status()
response_data = resp.json() response_data = resp.json()
self.token = response_data.get("accessToken") self.token = response_data.get("accessToken")

View File

@ -1,6 +1,6 @@
import logging import logging
import requests import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_login import current_user from flask_login import current_user
from flask_restx import Resource, fields from flask_restx import Resource, fields
@ -119,7 +119,7 @@ class OAuthDataSourceBinding(Resource):
return {"error": "Invalid code"}, 400 return {"error": "Invalid code"}, 400
try: try:
oauth_provider.get_access_token(code) oauth_provider.get_access_token(code)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )
@ -152,7 +152,7 @@ class OAuthDataSourceSync(Resource):
return {"error": "Invalid provider"}, 400 return {"error": "Invalid provider"}, 400
try: try:
oauth_provider.sync_data_source(binding_id) oauth_provider.sync_data_source(binding_id)
except requests.HTTPError as e: except httpx.HTTPStatusError as e:
logger.exception( logger.exception(
"An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text "An error occurred during the OAuthCallback process with %s: %s", provider, e.response.text
) )

View File

@ -1,6 +1,6 @@
import logging import logging
import requests import httpx
from flask import current_app, redirect, request from flask import current_app, redirect, request
from flask_restx import Resource from flask_restx import Resource
from sqlalchemy import select from sqlalchemy import select
@ -101,8 +101,10 @@ class OAuthCallback(Resource):
try: try:
token = oauth_provider.get_access_token(code) token = oauth_provider.get_access_token(code)
user_info = oauth_provider.get_user_info(token) user_info = oauth_provider.get_user_info(token)
except requests.RequestException as e: except httpx.RequestError as e:
error_text = e.response.text if e.response else str(e) error_text = str(e)
if isinstance(e, httpx.HTTPStatusError):
error_text = e.response.text
logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text)
return {"error": "OAuth process failed"}, 400 return {"error": "OAuth process failed"}, 400

View File

@ -1,7 +1,7 @@
import json import json
import logging import logging
import requests import httpx
from flask_restx import Resource, fields, reqparse from flask_restx import Resource, fields, reqparse
from packaging import version from packaging import version
@ -57,7 +57,11 @@ class VersionApi(Resource):
return result return result
try: try:
response = requests.get(check_update_url, {"current_version": args["current_version"]}, timeout=(3, 10)) response = httpx.get(
check_update_url,
params={"current_version": args["current_version"]},
timeout=httpx.Timeout(connect=3, read=10),
)
except Exception as error: except Exception as error:
logger.warning("Check update version error: %s.", str(error)) logger.warning("Check update version error: %s.", str(error))
result["version"] = args["current_version"] result["version"] = args["current_version"]

View File

@ -8,7 +8,7 @@ from collections import deque
from collections.abc import Sequence from collections.abc import Sequence
from datetime import datetime from datetime import datetime
import requests import httpx
from opentelemetry import trace as trace_api from opentelemetry import trace as trace_api
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource
@ -65,13 +65,13 @@ class TraceClient:
def api_check(self): def api_check(self):
try: try:
response = requests.head(self.endpoint, timeout=5) response = httpx.head(self.endpoint, timeout=5)
if response.status_code == 405: if response.status_code == 405:
return True return True
else: else:
logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code) logger.debug("AliyunTrace API check failed: Unexpected status code: %s", response.status_code)
return False return False
except requests.RequestException as e: except httpx.RequestError as e:
logger.debug("AliyunTrace API check failed: %s", str(e)) logger.debug("AliyunTrace API check failed: %s", str(e))
raise ValueError(f"AliyunTrace API check failed: {str(e)}") raise ValueError(f"AliyunTrace API check failed: {str(e)}")

View File

@ -1,7 +1,7 @@
import urllib.parse import urllib.parse
from dataclasses import dataclass from dataclasses import dataclass
import requests import httpx
@dataclass @dataclass
@ -58,7 +58,7 @@ class GitHubOAuth(OAuth):
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
} }
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers) response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
@ -70,11 +70,11 @@ class GitHubOAuth(OAuth):
def get_raw_user_info(self, token: str): def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"} headers = {"Authorization": f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers) response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status() response.raise_for_status()
user_info = response.json() user_info = response.json()
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers) email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json() email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
@ -112,7 +112,7 @@ class GoogleOAuth(OAuth):
"redirect_uri": self.redirect_uri, "redirect_uri": self.redirect_uri,
} }
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers) response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
@ -124,7 +124,7 @@ class GoogleOAuth(OAuth):
def get_raw_user_info(self, token: str): def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers) response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status() response.raise_for_status()
return response.json() return response.json()

View File

@ -1,7 +1,7 @@
import urllib.parse import urllib.parse
from typing import Any from typing import Any
import requests import httpx
from flask_login import current_user from flask_login import current_user
from sqlalchemy import select from sqlalchemy import select
@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"} headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret) auth = (self.client_id, self.client_secret)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers) response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json() response_json = response.json()
access_token = response_json.get("access_token") access_token = response_json.get("access_token")
@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json() response_json = response.json()
results.extend(response_json.get("results", [])) results.extend(response_json.get("results", []))
@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers) response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json() response_json = response.json()
if response.status_code != 200: if response.status_code != 200:
message = response_json.get("message", "unknown error") message = response_json.get("message", "unknown error")
@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.get(url=self._NOTION_BOT_USER, headers=headers) response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json() response_json = response.json()
if "object" in response_json and response_json["object"] == "user": if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"] user_type = response_json["type"]
@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}", "Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28", "Notion-Version": "2022-06-28",
} }
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json() response_json = response.json()
results.extend(response_json.get("results", [])) results.extend(response_json.get("results", []))

View File

@ -1,6 +1,6 @@
import json import json
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -36,7 +36,7 @@ class FirecrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data) return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json import json
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data) return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import json import json
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class JinaAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
def _post_request(self, url, data, headers): def _post_request(self, url, data, headers):
return requests.post(url, headers=headers, json=data) return httpx.post(url, headers=headers, json=data)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,7 +1,7 @@
import json import json
from urllib.parse import urljoin from urllib.parse import urljoin
import requests import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase from services.auth.api_key_auth_base import ApiKeyAuthBase
@ -31,7 +31,7 @@ class WatercrawlAuth(ApiKeyAuthBase):
return {"Content-Type": "application/json", "X-API-KEY": self.api_key} return {"Content-Type": "application/json", "X-API-KEY": self.api_key}
def _get_request(self, url, headers): def _get_request(self, url, headers):
return requests.get(url, headers=headers) return httpx.get(url, headers=headers)
def _handle_error(self, response): def _handle_error(self, response):
if response.status_code in {402, 409, 500}: if response.status_code in {402, 409, 500}:

View File

@ -1,6 +1,6 @@
import os import os
import requests import httpx
class OperationService: class OperationService:
@ -12,7 +12,7 @@ class OperationService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key} headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}" url = f"{cls.base_url}{endpoint}"
response = requests.request(method, url, json=json, params=params, headers=headers) response = httpx.request(method, url, json=json, params=params, headers=headers)
return response.json() return response.json()

View File

@ -3,7 +3,7 @@ import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
import requests import httpx
from flask_login import current_user from flask_login import current_user
from core.helper import encrypter from core.helper import encrypter
@ -216,7 +216,7 @@ class WebsiteService:
@classmethod @classmethod
def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
if not request.options.crawl_sub_pages: if not request.options.crawl_sub_pages:
response = requests.get( response = httpx.get(
f"https://r.jina.ai/{request.url}", f"https://r.jina.ai/{request.url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
) )
@ -224,7 +224,7 @@ class WebsiteService:
raise ValueError("Failed to crawl:") raise ValueError("Failed to crawl:")
return {"status": "active", "data": response.json().get("data")} return {"status": "active", "data": response.json().get("data")}
else: else:
response = requests.post( response = httpx.post(
"https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
json={ json={
"url": request.url, "url": request.url,
@ -287,7 +287,7 @@ class WebsiteService:
@classmethod @classmethod
def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
response = requests.post( response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
@ -303,7 +303,7 @@ class WebsiteService:
} }
if crawl_status_data["status"] == "completed": if crawl_status_data["status"] == "completed":
response = requests.post( response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
@ -362,7 +362,7 @@ class WebsiteService:
@classmethod @classmethod
def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
if not job_id: if not job_id:
response = requests.get( response = httpx.get(
f"https://r.jina.ai/{url}", f"https://r.jina.ai/{url}",
headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
) )
@ -371,7 +371,7 @@ class WebsiteService:
return dict(response.json().get("data", {})) return dict(response.json().get("data", {}))
else: else:
# Get crawl status first # Get crawl status first
status_response = requests.post( status_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id}, json={"taskId": job_id},
@ -381,7 +381,7 @@ class WebsiteService:
raise ValueError("Crawl job is not completed") raise ValueError("Crawl job is not completed")
# Get processed data # Get processed data
data_response = requests.post( data_response = httpx.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},

View File

@ -1,8 +1,8 @@
import os import os
from typing import Literal from typing import Literal
import httpx
import pytest import pytest
import requests
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse
from core.tools.entities.common_entities import I18nObject from core.tools.entities.common_entities import I18nObject
@ -27,13 +27,11 @@ class MockedHttp:
@classmethod @classmethod
def requests_request( def requests_request(
cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs cls, method: Literal["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"], url: str, **kwargs
) -> requests.Response: ) -> httpx.Response:
""" """
Mocked requests.request Mocked httpx.request
""" """
request = requests.PreparedRequest() request = httpx.Request(method, url)
request.method = method
request.url = url
if url.endswith("/tools"): if url.endswith("/tools"):
content = PluginDaemonBasicResponse[list[ToolProviderEntity]]( content = PluginDaemonBasicResponse[list[ToolProviderEntity]](
code=0, message="success", data=cls.list_tools() code=0, message="success", data=cls.list_tools()
@ -41,8 +39,7 @@ class MockedHttp:
else: else:
raise ValueError("") raise ValueError("")
response = requests.Response() response = httpx.Response(status_code=200)
response.status_code = 200
response.request = request response.request = request
response._content = content.encode("utf-8") response._content = content.encode("utf-8")
return response return response
@ -54,7 +51,7 @@ MOCK_SWITCH = os.getenv("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture @pytest.fixture
def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch): def setup_http_mock(request, monkeypatch: pytest.MonkeyPatch):
if MOCK_SWITCH: if MOCK_SWITCH:
monkeypatch.setattr(requests, "request", MockedHttp.requests_request) monkeypatch.setattr(httpx, "request", MockedHttp.requests_request)
def unpatch(): def unpatch():
monkeypatch.undo() monkeypatch.undo()

View File

@ -6,7 +6,7 @@ Test Clickzetta integration in Docker environment
import os import os
import time import time
import requests import httpx
from clickzetta import connect from clickzetta import connect
@ -66,7 +66,7 @@ def test_dify_api():
max_retries = 30 max_retries = 30
for i in range(max_retries): for i in range(max_retries):
try: try:
response = requests.get(f"{base_url}/console/api/health") response = httpx.get(f"{base_url}/console/api/health")
if response.status_code == 200: if response.status_code == 200:
print("✓ Dify API is ready") print("✓ Dify API is ready")
break break

View File

@ -201,9 +201,9 @@ class TestOAuthCallback:
mock_db.session.rollback = MagicMock() mock_db.session.rollback = MagicMock()
# Import the real requests module to create a proper exception # Import the real requests module to create a proper exception
import requests import httpx
request_exception = requests.exceptions.RequestException("OAuth error") request_exception = httpx.RequestError("OAuth error")
request_exception.response = MagicMock() request_exception.response = MagicMock()
request_exception.response.text = str(exception) request_exception.response.text = str(exception)

View File

@ -1,8 +1,8 @@
import urllib.parse import urllib.parse
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
@ -68,7 +68,7 @@ class TestGitHubOAuth(BaseOAuthTest):
({}, None, True), ({}, None, True),
], ],
) )
@patch("requests.post") @patch("httpx.post")
def test_should_retrieve_access_token( def test_should_retrieve_access_token(
self, mock_post, oauth, mock_response, response_data, expected_token, should_raise self, mock_post, oauth, mock_response, response_data, expected_token, should_raise
): ):
@ -105,7 +105,7 @@ class TestGitHubOAuth(BaseOAuthTest):
), ),
], ],
) )
@patch("requests.get") @patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email): def test_should_retrieve_user_info_correctly(self, mock_get, oauth, user_data, email_data, expected_email):
user_response = MagicMock() user_response = MagicMock()
user_response.json.return_value = user_data user_response.json.return_value = user_data
@ -121,11 +121,11 @@ class TestGitHubOAuth(BaseOAuthTest):
assert user_info.name == user_data["name"] assert user_info.name == user_data["name"]
assert user_info.email == expected_email assert user_info.email == expected_email
@patch("requests.get") @patch("httpx.get")
def test_should_handle_network_errors(self, mock_get, oauth): def test_should_handle_network_errors(self, mock_get, oauth):
mock_get.side_effect = requests.exceptions.RequestException("Network error") mock_get.side_effect = httpx.RequestError("Network error")
with pytest.raises(requests.exceptions.RequestException): with pytest.raises(httpx.RequestError):
oauth.get_raw_user_info("test_token") oauth.get_raw_user_info("test_token")
@ -167,7 +167,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({}, None, True), ({}, None, True),
], ],
) )
@patch("requests.post") @patch("httpx.post")
def test_should_retrieve_access_token( def test_should_retrieve_access_token(
self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise self, mock_post, oauth, oauth_config, mock_response, response_data, expected_token, should_raise
): ):
@ -201,7 +201,7 @@ class TestGoogleOAuth(BaseOAuthTest):
({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string ({"sub": "123", "email": "test@example.com", "name": "Test User"}, ""), # Always returns empty string
], ],
) )
@patch("requests.get") @patch("httpx.get")
def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name): def test_should_retrieve_user_info_correctly(self, mock_get, oauth, mock_response, user_data, expected_name):
mock_response.json.return_value = user_data mock_response.json.return_value = user_data
mock_get.return_value = mock_response mock_get.return_value = mock_response
@ -217,12 +217,12 @@ class TestGoogleOAuth(BaseOAuthTest):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"exception_type", "exception_type",
[ [
requests.exceptions.HTTPError, httpx.HTTPError,
requests.exceptions.ConnectionError, httpx.ConnectError,
requests.exceptions.Timeout, httpx.TimeoutException,
], ],
) )
@patch("requests.get") @patch("httpx.get")
def test_should_handle_http_errors(self, mock_get, oauth, exception_type): def test_should_handle_http_errors(self, mock_get, oauth, exception_type):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.raise_for_status.side_effect = exception_type("Error") mock_response.raise_for_status.side_effect = exception_type("Error")

View File

@ -6,8 +6,8 @@ import json
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import httpx
import pytest import pytest
import requests
from services.auth.api_key_auth_factory import ApiKeyAuthFactory from services.auth.api_key_auth_factory import ApiKeyAuthFactory
from services.auth.api_key_auth_service import ApiKeyAuthService from services.auth.api_key_auth_service import ApiKeyAuthService
@ -26,7 +26,7 @@ class TestAuthIntegration:
self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}}
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session):
"""Test complete authentication flow: request → validation → encryption → storage""" """Test complete authentication flow: request → validation → encryption → storage"""
@ -47,7 +47,7 @@ class TestAuthIntegration:
mock_session.add.assert_called_once() mock_session.add.assert_called_once()
mock_session.commit.assert_called_once() mock_session.commit.assert_called_once()
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_cross_component_integration(self, mock_http): def test_cross_component_integration(self, mock_http):
"""Test factory → provider → HTTP call integration""" """Test factory → provider → HTTP call integration"""
mock_http.return_value = self._create_success_response() mock_http.return_value = self._create_success_response()
@ -97,7 +97,7 @@ class TestAuthIntegration:
assert "another_secret" not in factory_str assert "another_secret" not in factory_str
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
@patch("services.auth.api_key_auth_service.encrypter.encrypt_token") @patch("services.auth.api_key_auth_service.encrypter.encrypt_token")
def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session):
"""Test concurrent authentication creation safety""" """Test concurrent authentication creation safety"""
@ -142,31 +142,31 @@ class TestAuthIntegration:
with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): with pytest.raises((ValueError, KeyError, TypeError, AttributeError)):
ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input)
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_http_error_handling(self, mock_http): def test_http_error_handling(self, mock_http):
"""Test proper HTTP error handling""" """Test proper HTTP error handling"""
mock_response = Mock() mock_response = Mock()
mock_response.status_code = 401 mock_response.status_code = 401
mock_response.text = '{"error": "Unauthorized"}' mock_response.text = '{"error": "Unauthorized"}'
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("Unauthorized") mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized")
mock_http.return_value = mock_response mock_http.return_value = mock_response
# PT012: Split into single statement for pytest.raises # PT012: Split into single statement for pytest.raises
factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials)
with pytest.raises((requests.exceptions.HTTPError, Exception)): with pytest.raises((httpx.HTTPError, Exception)):
factory.validate_credentials() factory.validate_credentials()
@patch("services.auth.api_key_auth_service.db.session") @patch("services.auth.api_key_auth_service.db.session")
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_network_failure_recovery(self, mock_http, mock_session): def test_network_failure_recovery(self, mock_http, mock_session):
"""Test system recovery from network failures""" """Test system recovery from network failures"""
mock_http.side_effect = requests.exceptions.RequestException("Network timeout") mock_http.side_effect = httpx.RequestError("Network timeout")
mock_session.add = Mock() mock_session.add = Mock()
mock_session.commit = Mock() mock_session.commit = Mock()
args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials}
with pytest.raises(requests.exceptions.RequestException): with pytest.raises(httpx.RequestError):
ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args)
mock_session.commit.assert_not_called() mock_session.commit.assert_not_called()

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from services.auth.firecrawl.firecrawl import FirecrawlAuth from services.auth.firecrawl.firecrawl import FirecrawlAuth
@ -64,7 +64,7 @@ class TestFirecrawlAuth:
FirecrawlAuth(credentials) FirecrawlAuth(credentials)
assert str(exc_info.value) == expected_error assert str(exc_info.value) == expected_error
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -95,7 +95,7 @@ class TestFirecrawlAuth:
(500, "Internal server error"), (500, "Internal server error"),
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes""" """Test handling of various HTTP error codes"""
mock_response = MagicMock() mock_response = MagicMock()
@ -115,7 +115,7 @@ class TestFirecrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error (401, "Not JSON", True, "Expecting value"), # JSON decode error
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_unexpected_errors( def test_should_handle_unexpected_errors(
self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance
): ):
@ -134,13 +134,13 @@ class TestFirecrawlAuth:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception_type", "exception_message"), ("exception_type", "exception_message"),
[ [
(requests.ConnectionError, "Network error"), (httpx.ConnectError, "Network error"),
(requests.Timeout, "Request timeout"), (httpx.TimeoutException, "Request timeout"),
(requests.ReadTimeout, "Read timeout"), (httpx.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"), (httpx.ConnectTimeout, "Connection timeout"),
], ],
) )
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts""" """Test handling of various network-related errors including timeouts"""
mock_post.side_effect = exception_type(exception_message) mock_post.side_effect = exception_type(exception_message)
@ -162,7 +162,7 @@ class TestFirecrawlAuth:
FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value) assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_use_custom_base_url_in_validation(self, mock_post): def test_should_use_custom_base_url_in_validation(self, mock_post):
"""Test that custom base URL is used in validation""" """Test that custom base URL is used in validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -179,12 +179,12 @@ class TestFirecrawlAuth:
assert result is True assert result is True
assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl"
@patch("services.auth.firecrawl.firecrawl.requests.post") @patch("services.auth.firecrawl.firecrawl.httpx.post")
def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message""" """Test that timeout errors are handled gracefully with appropriate error message"""
mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds") mock_post.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info: with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials() auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message # Verify the timeout exception is raised with original message

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from services.auth.jina.jina import JinaAuth from services.auth.jina.jina import JinaAuth
@ -35,7 +35,7 @@ class TestJinaAuth:
JinaAuth(credentials) JinaAuth(credentials)
assert str(exc_info.value) == "No API key provided" assert str(exc_info.value) == "No API key provided"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_validate_valid_credentials_successfully(self, mock_post): def test_should_validate_valid_credentials_successfully(self, mock_post):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -53,7 +53,7 @@ class TestJinaAuth:
json={"url": "https://example.com"}, json={"url": "https://example.com"},
) )
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_402_error(self, mock_post): def test_should_handle_http_402_error(self, mock_post):
"""Test handling of 402 Payment Required error""" """Test handling of 402 Payment Required error"""
mock_response = MagicMock() mock_response = MagicMock()
@ -68,7 +68,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required" assert str(exc_info.value) == "Failed to authorize. Status code: 402. Error: Payment required"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_409_error(self, mock_post): def test_should_handle_http_409_error(self, mock_post):
"""Test handling of 409 Conflict error""" """Test handling of 409 Conflict error"""
mock_response = MagicMock() mock_response = MagicMock()
@ -83,7 +83,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error" assert str(exc_info.value) == "Failed to authorize. Status code: 409. Error: Conflict error"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_http_500_error(self, mock_post): def test_should_handle_http_500_error(self, mock_post):
"""Test handling of 500 Internal Server Error""" """Test handling of 500 Internal Server Error"""
mock_response = MagicMock() mock_response = MagicMock()
@ -98,7 +98,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error" assert str(exc_info.value) == "Failed to authorize. Status code: 500. Error: Internal server error"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_with_text_response(self, mock_post): def test_should_handle_unexpected_error_with_text_response(self, mock_post):
"""Test handling of unexpected errors with text response""" """Test handling of unexpected errors with text response"""
mock_response = MagicMock() mock_response = MagicMock()
@ -114,7 +114,7 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden" assert str(exc_info.value) == "Failed to authorize. Status code: 403. Error: Forbidden"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_unexpected_error_without_text(self, mock_post): def test_should_handle_unexpected_error_without_text(self, mock_post):
"""Test handling of unexpected errors without text response""" """Test handling of unexpected errors without text response"""
mock_response = MagicMock() mock_response = MagicMock()
@ -130,15 +130,15 @@ class TestJinaAuth:
auth.validate_credentials() auth.validate_credentials()
assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404" assert str(exc_info.value) == "Unexpected error occurred while trying to authorize. Status code: 404"
@patch("services.auth.jina.jina.requests.post") @patch("services.auth.jina.jina.httpx.post")
def test_should_handle_network_errors(self, mock_post): def test_should_handle_network_errors(self, mock_post):
"""Test handling of network connection errors""" """Test handling of network connection errors"""
mock_post.side_effect = requests.ConnectionError("Network error") mock_post.side_effect = httpx.ConnectError("Network error")
credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} credentials = {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}}
auth = JinaAuth(credentials) auth = JinaAuth(credentials)
with pytest.raises(requests.ConnectionError): with pytest.raises(httpx.ConnectError):
auth.validate_credentials() auth.validate_credentials()
def test_should_not_expose_api_key_in_error_messages(self): def test_should_not_expose_api_key_in_error_messages(self):

View File

@ -1,7 +1,7 @@
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import httpx
import pytest import pytest
import requests
from services.auth.watercrawl.watercrawl import WatercrawlAuth from services.auth.watercrawl.watercrawl import WatercrawlAuth
@ -64,7 +64,7 @@ class TestWatercrawlAuth:
WatercrawlAuth(credentials) WatercrawlAuth(credentials)
assert str(exc_info.value) == expected_error assert str(exc_info.value) == expected_error
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance):
"""Test successful credential validation""" """Test successful credential validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -87,7 +87,7 @@ class TestWatercrawlAuth:
(500, "Internal server error"), (500, "Internal server error"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance):
"""Test handling of various HTTP error codes""" """Test handling of various HTTP error codes"""
mock_response = MagicMock() mock_response = MagicMock()
@ -107,7 +107,7 @@ class TestWatercrawlAuth:
(401, "Not JSON", True, "Expecting value"), # JSON decode error (401, "Not JSON", True, "Expecting value"), # JSON decode error
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_unexpected_errors( def test_should_handle_unexpected_errors(
self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance
): ):
@ -126,13 +126,13 @@ class TestWatercrawlAuth:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("exception_type", "exception_message"), ("exception_type", "exception_message"),
[ [
(requests.ConnectionError, "Network error"), (httpx.ConnectError, "Network error"),
(requests.Timeout, "Request timeout"), (httpx.TimeoutException, "Request timeout"),
(requests.ReadTimeout, "Read timeout"), (httpx.ReadTimeout, "Read timeout"),
(requests.ConnectTimeout, "Connection timeout"), (httpx.ConnectTimeout, "Connection timeout"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance):
"""Test handling of various network-related errors including timeouts""" """Test handling of various network-related errors including timeouts"""
mock_get.side_effect = exception_type(exception_message) mock_get.side_effect = exception_type(exception_message)
@ -154,7 +154,7 @@ class TestWatercrawlAuth:
WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}})
assert "super_secret_key_12345" not in str(exc_info.value) assert "super_secret_key_12345" not in str(exc_info.value)
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_custom_base_url_in_validation(self, mock_get): def test_should_use_custom_base_url_in_validation(self, mock_get):
"""Test that custom base URL is used in validation""" """Test that custom base URL is used in validation"""
mock_response = MagicMock() mock_response = MagicMock()
@ -179,7 +179,7 @@ class TestWatercrawlAuth:
("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"),
], ],
) )
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url):
"""Test that urljoin is used correctly for URL construction with various base URLs""" """Test that urljoin is used correctly for URL construction with various base URLs"""
mock_response = MagicMock() mock_response = MagicMock()
@ -193,12 +193,12 @@ class TestWatercrawlAuth:
# Verify the correct URL was called # Verify the correct URL was called
assert mock_get.call_args[0][0] == expected_url assert mock_get.call_args[0][0] == expected_url
@patch("services.auth.watercrawl.watercrawl.requests.get") @patch("services.auth.watercrawl.watercrawl.httpx.get")
def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance): def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance):
"""Test that timeout errors are handled gracefully with appropriate error message""" """Test that timeout errors are handled gracefully with appropriate error message"""
mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds") mock_get.side_effect = httpx.TimeoutException("The request timed out after 30 seconds")
with pytest.raises(requests.Timeout) as exc_info: with pytest.raises(httpx.TimeoutException) as exc_info:
auth_instance.validate_credentials() auth_instance.validate_credentials()
# Verify the timeout exception is raised with original message # Verify the timeout exception is raised with original message