fix: fix use fastopenapi lead user is anonymouse (#32236)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
wangxiaolei 2026-02-11 15:53:51 +08:00 committed by GitHub
parent 5b4c7b2a40
commit 2f87ecc0ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 335 additions and 133 deletions

View File

@ -1,6 +1,7 @@
import urllib.parse import urllib.parse
import httpx import httpx
from flask_restx import Resource
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
import services import services
@ -10,12 +11,12 @@ from controllers.common.errors import (
RemoteFileUploadError, RemoteFileUploadError,
UnsupportedFileTypeError, UnsupportedFileTypeError,
) )
from controllers.fastopenapi import console_router from controllers.console import console_ns
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from extensions.ext_database import db from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant from libs.login import current_account_with_tenant, login_required
from services.file_service import FileService from services.file_service import FileService
@ -23,12 +24,10 @@ class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch") url: str = Field(..., description="URL to fetch")
@console_router.get( @console_ns.route("/remote-files/<path:url>")
"/remote-files/<path:url>", class GetRemoteFileInfo(Resource):
response_model=RemoteFileInfo, @login_required
tags=["console"], def get(self, url: str):
)
def get_remote_file_info(url: str) -> RemoteFileInfo:
decoded_url = urllib.parse.unquote(url) decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url) resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK: if resp.status_code != httpx.codes.OK:
@ -37,32 +36,34 @@ def get_remote_file_info(url: str) -> RemoteFileInfo:
return RemoteFileInfo( return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"), file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)), file_length=int(resp.headers.get("Content-Length", 0)),
) ).model_dump(mode="json")
@console_router.post( @console_ns.route("/remote-files/upload")
"/remote-files/upload", class RemoteFileUpload(Resource):
response_model=FileWithSignedUrl, @login_required
tags=["console"], def post(self):
status_code=201, payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
)
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
url = payload.url url = payload.url
# Try to fetch remote file metadata/content first
try: try:
resp = ssrf_proxy.head(url=url) resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK: if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK: if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e: except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp) file_info = helpers.guess_file_info_from_response(resp)
# Enforce file size limit with 400 (Bad Request) per tests' expectation
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError raise FileTooLargeError()
# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try: try:
@ -79,7 +80,9 @@ def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
except services.errors.file.UnsupportedFileTypeError: except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError() raise UnsupportedFileTypeError()
return FileWithSignedUrl( # Success: return created resource with 201 status
return (
FileWithSignedUrl(
id=upload_file.id, id=upload_file.id,
name=upload_file.name, name=upload_file.name,
size=upload_file.size, size=upload_file.size,
@ -88,4 +91,6 @@ def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
mime_type=upload_file.mime_type, mime_type=upload_file.mime_type,
created_by=upload_file.created_by, created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()), created_at=int(upload_file.created_at.timestamp()),
).model_dump(mode="json"),
201,
) )

View File

@ -1,92 +1,286 @@
import builtins """Tests for remote file upload API endpoints using Flask-RESTX."""
import contextlib
from datetime import datetime from datetime import datetime
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import Mock, patch
import httpx import httpx
import pytest import pytest
from flask import Flask from flask import Flask, g
from flask.views import MethodView
from extensions import ext_fastopenapi
if not hasattr(builtins, "MethodView"):
builtins.MethodView = MethodView # type: ignore[attr-defined]
@pytest.fixture @pytest.fixture
def app() -> Flask: def app() -> Flask:
"""Create Flask app for testing."""
app = Flask(__name__) app = Flask(__name__)
app.config["TESTING"] = True app.config["TESTING"] = True
app.config["SECRET_KEY"] = "test-secret-key"
return app return app
def test_console_remote_files_fastopenapi_get_info(app: Flask): @pytest.fixture
ext_fastopenapi.init_app(app) def client(app):
"""Create test client with console blueprint registered."""
from controllers.console import bp
app.register_blueprint(bp)
return app.test_client()
@pytest.fixture
def mock_account():
"""Create a mock account for testing."""
from models import Account
account = Mock(spec=Account)
account.id = "test-account-id"
account.current_tenant_id = "test-tenant-id"
return account
@pytest.fixture
def auth_ctx(app, mock_account):
"""Context manager to set auth/tenant context in flask.g for a request."""
@contextlib.contextmanager
def _ctx():
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
yield
return _ctx
class TestGetRemoteFileInfo:
"""Test GET /console/api/remote-files/<path:url> endpoint."""
def test_get_remote_file_info_success(self, app, client, mock_account):
"""Test successful retrieval of remote file info."""
response = httpx.Response( response = httpx.Response(
200, 200,
request=httpx.Request("HEAD", "http://example.com/file.txt"), request=httpx.Request("HEAD", "http://example.com/file.txt"),
headers={"Content-Type": "text/plain", "Content-Length": "10"}, headers={"Content-Type": "text/plain", "Content-Length": "1024"},
) )
with patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response): with (
client = app.test_client() patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response),
patch("libs.login.check_csrf_token", return_value=None),
):
with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt"
resp = client.get(f"/console/api/remote-files/{encoded_url}") resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200 assert resp.status_code == 200
assert resp.get_json() == {"file_type": "text/plain", "file_length": 10} data = resp.get_json()
assert data["file_type"] == "text/plain"
assert data["file_length"] == 1024
def test_console_remote_files_fastopenapi_upload(app: Flask):
ext_fastopenapi.init_app(app)
def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account):
"""Test fallback to GET when HEAD returns non-200 status."""
head_response = httpx.Response( head_response = httpx.Response(
404,
request=httpx.Request("HEAD", "http://example.com/file.pdf"),
)
get_response = httpx.Response(
200, 200,
request=httpx.Request("GET", "http://example.com/file.txt"), request=httpx.Request("GET", "http://example.com/file.pdf"),
content=b"hello", headers={"Content-Type": "application/pdf", "Content-Length": "2048"},
)
file_info = SimpleNamespace(
extension="txt",
size=5,
filename="file.txt",
mimetype="text/plain",
)
uploaded = SimpleNamespace(
id="file-id",
name="file.txt",
size=5,
extension="txt",
mime_type="text/plain",
created_by="user-id",
created_at=datetime(2024, 1, 1),
) )
with ( with (
patch("controllers.console.remote_files.db", new=SimpleNamespace(engine=object())), patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response),
patch("controllers.console.remote_files.helpers.guess_file_info_from_response", return_value=file_info), patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response),
patch("controllers.console.remote_files.FileService.is_file_size_within_limit", return_value=True), patch("libs.login.check_csrf_token", return_value=None),
patch("controllers.console.remote_files.FileService.__init__", return_value=None),
patch("controllers.console.remote_files.current_account_with_tenant", return_value=(object(), "tenant-id")),
patch("controllers.console.remote_files.FileService.upload_file", return_value=uploaded),
patch("controllers.console.remote_files.file_helpers.get_signed_file_url", return_value="signed-url"),
): ):
client = app.test_client() with app.test_request_context():
g._login_user = mock_account
g._current_tenant = mock_account.current_tenant_id
encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf"
resp = client.get(f"/console/api/remote-files/{encoded_url}")
assert resp.status_code == 200
data = resp.get_json()
assert data["file_type"] == "application/pdf"
assert data["file_length"] == 2048
class TestRemoteFileUpload:
"""Test POST /console/api/remote-files/upload endpoint."""
@pytest.mark.parametrize(
("head_status", "use_get"),
[
(200, False), # HEAD succeeds
(405, True), # HEAD fails -> fallback GET
],
)
def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get):
url = "http://example.com/file.pdf"
head_resp = httpx.Response(
head_status,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
)
get_resp = httpx.Response(
200,
request=httpx.Request("GET", url),
headers={"Content-Type": "application/pdf", "Content-Length": "1024"},
content=b"file content",
)
file_info = SimpleNamespace(
extension="pdf",
size=1024,
filename="file.pdf",
mimetype="application/pdf",
)
uploaded_file = SimpleNamespace(
id="uploaded-file-id",
name="file.pdf",
size=1024,
extension="pdf",
mime_type="application/pdf",
created_by="test-account-id",
created_at=datetime(2024, 1, 1, 12, 0, 0),
)
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head,
patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get,
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=True,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("controllers.console.remote_files.FileService") as mock_file_service,
patch(
"controllers.console.remote_files.file_helpers.get_signed_file_url",
return_value="http://example.com/signed-url",
),
patch("libs.login.check_csrf_token", return_value=None),
):
mock_file_service.return_value.upload_file.return_value = uploaded_file
with auth_ctx():
resp = client.post( resp = client.post(
"/console/api/remote-files/upload", "/console/api/remote-files/upload",
json={"url": "http://example.com/file.txt"}, json={"url": url},
) )
assert resp.status_code == 201 assert resp.status_code == 201
assert resp.get_json() == { p_head.assert_called_once()
"id": "file-id", # GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds
"name": "file.txt", p_get.assert_called_once()
"size": 5, mock_file_service.return_value.upload_file.assert_called_once()
"extension": "txt",
"url": "signed-url", data = resp.get_json()
"mime_type": "text/plain", assert data["id"] == "uploaded-file-id"
"created_by": "user-id", assert data["name"] == "file.pdf"
"created_at": int(uploaded.created_at.timestamp()), assert data["size"] == 1024
} assert data["extension"] == "pdf"
assert data["url"] == "http://example.com/signed-url"
assert data["mime_type"] == "application/pdf"
assert data["created_by"] == "test-account-id"
@pytest.mark.parametrize(
("size_ok", "raises", "expected_status", "expected_msg"),
[
# When size check fails in controller, API returns 413 with message "File size exceeded..."
(False, None, 413, "file size exceeded"),
# When service raises unsupported type, controller maps to 415 with message "File type not allowed."
(True, "unsupported", 415, "file type not allowed"),
],
)
def test_upload_remote_file_errors(
self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg
):
url = "http://example.com/x.pdf"
head_resp = httpx.Response(
200,
request=httpx.Request("HEAD", url),
headers={"Content-Type": "application/pdf", "Content-Length": "9"},
)
file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf")
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp),
patch(
"controllers.console.remote_files.helpers.guess_file_info_from_response",
return_value=file_info,
),
patch(
"controllers.console.remote_files.FileService.is_file_size_within_limit",
return_value=size_ok,
),
patch("controllers.console.remote_files.db", spec=["engine"]),
patch("libs.login.check_csrf_token", return_value=None),
):
if raises == "unsupported":
from services.errors.file import UnsupportedFileTypeError
with patch("controllers.console.remote_files.FileService") as mock_file_service:
mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad")
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
else:
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": url},
)
assert resp.status_code == expected_status
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert expected_msg in msg.lower()
def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx):
"""Test upload when fetching of remote file fails."""
with (
patch(
"controllers.console.remote_files.current_account_with_tenant",
return_value=(mock_account, "test-tenant-id"),
),
patch(
"controllers.console.remote_files.ssrf_proxy.head",
side_effect=httpx.RequestError("Connection failed"),
),
patch("libs.login.check_csrf_token", return_value=None),
):
with auth_ctx():
resp = client.post(
"/console/api/remote-files/upload",
json={"url": "http://unreachable.com/file.pdf"},
)
assert resp.status_code == 400
data = resp.get_json()
msg = (data.get("error") or {}).get("message") or data.get("message", "")
assert "failed to fetch" in msg.lower()

View File

@ -496,6 +496,9 @@ class TestSchemaResolverClass:
avg_time_no_cache = sum(results1) / len(results1) avg_time_no_cache = sum(results1) / len(results1)
# Second run (with cache) - run multiple times # Second run (with cache) - run multiple times
# Warm up cache first
resolve_dify_schema_refs(schema)
results2 = [] results2 = []
for _ in range(3): for _ in range(3):
start = time.perf_counter() start = time.perf_counter()