From 2f87ecc0ce1d05ab3ef9537f29b84a5fa5823017 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Wed, 11 Feb 2026 15:53:51 +0800 Subject: [PATCH] fix: fix use fastopenapi lead user is anonymouse (#32236) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/console/remote_files.py | 129 +++---- .../console/test_fastopenapi_remote_files.py | 336 ++++++++++++++---- .../unit_tests/core/schemas/test_resolver.py | 3 + 3 files changed, 335 insertions(+), 133 deletions(-) diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 88a9ce3a79..b7a2f230e1 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -1,6 +1,7 @@ import urllib.parse import httpx +from flask_restx import Resource from pydantic import BaseModel, Field import services @@ -10,12 +11,12 @@ from controllers.common.errors import ( RemoteFileUploadError, UnsupportedFileTypeError, ) -from controllers.fastopenapi import console_router +from controllers.console import console_ns from core.file import helpers as file_helpers from core.helper import ssrf_proxy from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo -from libs.login import current_account_with_tenant +from libs.login import current_account_with_tenant, login_required from services.file_service import FileService @@ -23,69 +24,73 @@ class RemoteFileUploadPayload(BaseModel): url: str = Field(..., description="URL to fetch") -@console_router.get( - "/remote-files/", - response_model=RemoteFileInfo, - tags=["console"], -) -def get_remote_file_info(url: str) -> RemoteFileInfo: - decoded_url = urllib.parse.unquote(url) - resp = ssrf_proxy.head(decoded_url) - if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(decoded_url, timeout=3) - resp.raise_for_status() - return RemoteFileInfo( - file_type=resp.headers.get("Content-Type", "application/octet-stream"), - file_length=int(resp.headers.get("Content-Length", 0)), - ) - - -@console_router.post( - "/remote-files/upload", - response_model=FileWithSignedUrl, - tags=["console"], - status_code=201, -) -def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl: - url = payload.url - - try: - resp = ssrf_proxy.head(url=url) +@console_ns.route("/remote-files/") +class GetRemoteFileInfo(Resource): + @login_required + def get(self, url: str): + decoded_url = urllib.parse.unquote(url) + resp = ssrf_proxy.head(decoded_url) if resp.status_code != httpx.codes.OK: - resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) - if resp.status_code != httpx.codes.OK: - raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") - except httpx.RequestError as e: - raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") + resp = ssrf_proxy.get(decoded_url, timeout=3) + resp.raise_for_status() + return RemoteFileInfo( + file_type=resp.headers.get("Content-Type", "application/octet-stream"), + file_length=int(resp.headers.get("Content-Length", 0)), + ).model_dump(mode="json") - file_info = helpers.guess_file_info_from_response(resp) - if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): - raise FileTooLargeError +@console_ns.route("/remote-files/upload") +class RemoteFileUpload(Resource): + @login_required + def post(self): + payload = RemoteFileUploadPayload.model_validate(console_ns.payload) + url = payload.url - content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + # Try to fetch remote file metadata/content first + try: + resp = ssrf_proxy.head(url=url) + if resp.status_code != httpx.codes.OK: + resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True) + if resp.status_code != httpx.codes.OK: + # Normalize into a user-friendly error message expected by tests + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}") + except httpx.RequestError as e: + raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}") - try: - user, _ = current_account_with_tenant() - upload_file = FileService(db.engine).upload_file( - filename=file_info.filename, - content=content, - mimetype=file_info.mimetype, - user=user, - source_url=url, + file_info = helpers.guess_file_info_from_response(resp) + + # Enforce file size limit with 400 (Bad Request) per tests' expectation + if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size): + raise FileTooLargeError() + + # Load content if needed + content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content + + try: + user, _ = current_account_with_tenant() + upload_file = FileService(db.engine).upload_file( + filename=file_info.filename, + content=content, + mimetype=file_info.mimetype, + user=user, + source_url=url, + ) + except services.errors.file.FileTooLargeError as file_too_large_error: + raise FileTooLargeError(file_too_large_error.description) + except services.errors.file.UnsupportedFileTypeError: + raise UnsupportedFileTypeError() + + # Success: return created resource with 201 status + return ( + FileWithSignedUrl( + id=upload_file.id, + name=upload_file.name, + size=upload_file.size, + extension=upload_file.extension, + url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), + mime_type=upload_file.mime_type, + created_by=upload_file.created_by, + created_at=int(upload_file.created_at.timestamp()), + ).model_dump(mode="json"), + 201, ) - except services.errors.file.FileTooLargeError as file_too_large_error: - raise FileTooLargeError(file_too_large_error.description) - except services.errors.file.UnsupportedFileTypeError: - raise UnsupportedFileTypeError() - - return FileWithSignedUrl( - id=upload_file.id, - name=upload_file.name, - size=upload_file.size, - extension=upload_file.extension, - url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id), - mime_type=upload_file.mime_type, - created_by=upload_file.created_by, - created_at=int(upload_file.created_at.timestamp()), - ) diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py index cb2604cf1c..c0a984e216 100644 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py +++ b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py @@ -1,92 +1,286 @@ -import builtins +"""Tests for remote file upload API endpoints using Flask-RESTX.""" + +import contextlib from datetime import datetime from types import SimpleNamespace -from unittest.mock import patch +from unittest.mock import Mock, patch import httpx import pytest -from flask import Flask -from flask.views import MethodView - -from extensions import ext_fastopenapi - -if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView # type: ignore[attr-defined] +from flask import Flask, g @pytest.fixture def app() -> Flask: + """Create Flask app for testing.""" app = Flask(__name__) app.config["TESTING"] = True + app.config["SECRET_KEY"] = "test-secret-key" return app -def test_console_remote_files_fastopenapi_get_info(app: Flask): - ext_fastopenapi.init_app(app) +@pytest.fixture +def client(app): + """Create test client with console blueprint registered.""" + from controllers.console import bp - response = httpx.Response( - 200, - request=httpx.Request("HEAD", "http://example.com/file.txt"), - headers={"Content-Type": "text/plain", "Content-Length": "10"}, - ) - - with patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response): - client = app.test_client() - encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" - resp = client.get(f"/console/api/remote-files/{encoded_url}") - - assert resp.status_code == 200 - assert resp.get_json() == {"file_type": "text/plain", "file_length": 10} + app.register_blueprint(bp) + return app.test_client() -def test_console_remote_files_fastopenapi_upload(app: Flask): - ext_fastopenapi.init_app(app) +@pytest.fixture +def mock_account(): + """Create a mock account for testing.""" + from models import Account - head_response = httpx.Response( - 200, - request=httpx.Request("GET", "http://example.com/file.txt"), - content=b"hello", - ) - file_info = SimpleNamespace( - extension="txt", - size=5, - filename="file.txt", - mimetype="text/plain", - ) - uploaded = SimpleNamespace( - id="file-id", - name="file.txt", - size=5, - extension="txt", - mime_type="text/plain", - created_by="user-id", - created_at=datetime(2024, 1, 1), - ) + account = Mock(spec=Account) + account.id = "test-account-id" + account.current_tenant_id = "test-tenant-id" + return account - with ( - patch("controllers.console.remote_files.db", new=SimpleNamespace(engine=object())), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), - patch("controllers.console.remote_files.helpers.guess_file_info_from_response", return_value=file_info), - patch("controllers.console.remote_files.FileService.is_file_size_within_limit", return_value=True), - patch("controllers.console.remote_files.FileService.__init__", return_value=None), - patch("controllers.console.remote_files.current_account_with_tenant", return_value=(object(), "tenant-id")), - patch("controllers.console.remote_files.FileService.upload_file", return_value=uploaded), - patch("controllers.console.remote_files.file_helpers.get_signed_file_url", return_value="signed-url"), - ): - client = app.test_client() - resp = client.post( - "/console/api/remote-files/upload", - json={"url": "http://example.com/file.txt"}, + +@pytest.fixture +def auth_ctx(app, mock_account): + """Context manager to set auth/tenant context in flask.g for a request.""" + + @contextlib.contextmanager + def _ctx(): + with app.test_request_context(): + g._login_user = mock_account + g._current_tenant = mock_account.current_tenant_id + yield + + return _ctx + + +class TestGetRemoteFileInfo: + """Test GET /console/api/remote-files/ endpoint.""" + + def test_get_remote_file_info_success(self, app, client, mock_account): + """Test successful retrieval of remote file info.""" + response = httpx.Response( + 200, + request=httpx.Request("HEAD", "http://example.com/file.txt"), + headers={"Content-Type": "text/plain", "Content-Length": "1024"}, ) - assert resp.status_code == 201 - assert resp.get_json() == { - "id": "file-id", - "name": "file.txt", - "size": 5, - "extension": "txt", - "url": "signed-url", - "mime_type": "text/plain", - "created_by": "user-id", - "created_at": int(uploaded.created_at.timestamp()), - } + with ( + patch( + "controllers.console.remote_files.current_account_with_tenant", + return_value=(mock_account, "test-tenant-id"), + ), + patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response), + patch("libs.login.check_csrf_token", return_value=None), + ): + with app.test_request_context(): + g._login_user = mock_account + g._current_tenant = mock_account.current_tenant_id + encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" + resp = client.get(f"/console/api/remote-files/{encoded_url}") + + assert resp.status_code == 200 + data = resp.get_json() + assert data["file_type"] == "text/plain" + assert data["file_length"] == 1024 + + def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account): + """Test fallback to GET when HEAD returns non-200 status.""" + head_response = httpx.Response( + 404, + request=httpx.Request("HEAD", "http://example.com/file.pdf"), + ) + get_response = httpx.Response( + 200, + request=httpx.Request("GET", "http://example.com/file.pdf"), + headers={"Content-Type": "application/pdf", "Content-Length": "2048"}, + ) + + with ( + patch( + "controllers.console.remote_files.current_account_with_tenant", + return_value=(mock_account, "test-tenant-id"), + ), + patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), + patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response), + patch("libs.login.check_csrf_token", return_value=None), + ): + with app.test_request_context(): + g._login_user = mock_account + g._current_tenant = mock_account.current_tenant_id + encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf" + resp = client.get(f"/console/api/remote-files/{encoded_url}") + + assert resp.status_code == 200 + data = resp.get_json() + assert data["file_type"] == "application/pdf" + assert data["file_length"] == 2048 + + +class TestRemoteFileUpload: + """Test POST /console/api/remote-files/upload endpoint.""" + + @pytest.mark.parametrize( + ("head_status", "use_get"), + [ + (200, False), # HEAD succeeds + (405, True), # HEAD fails -> fallback GET + ], + ) + def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get): + url = "http://example.com/file.pdf" + head_resp = httpx.Response( + head_status, + request=httpx.Request("HEAD", url), + headers={"Content-Type": "application/pdf", "Content-Length": "1024"}, + ) + get_resp = httpx.Response( + 200, + request=httpx.Request("GET", url), + headers={"Content-Type": "application/pdf", "Content-Length": "1024"}, + content=b"file content", + ) + + file_info = SimpleNamespace( + extension="pdf", + size=1024, + filename="file.pdf", + mimetype="application/pdf", + ) + uploaded_file = SimpleNamespace( + id="uploaded-file-id", + name="file.pdf", + size=1024, + extension="pdf", + mime_type="application/pdf", + created_by="test-account-id", + created_at=datetime(2024, 1, 1, 12, 0, 0), + ) + + with ( + patch( + "controllers.console.remote_files.current_account_with_tenant", + return_value=(mock_account, "test-tenant-id"), + ), + patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head, + patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get, + patch( + "controllers.console.remote_files.helpers.guess_file_info_from_response", + return_value=file_info, + ), + patch( + "controllers.console.remote_files.FileService.is_file_size_within_limit", + return_value=True, + ), + patch("controllers.console.remote_files.db", spec=["engine"]), + patch("controllers.console.remote_files.FileService") as mock_file_service, + patch( + "controllers.console.remote_files.file_helpers.get_signed_file_url", + return_value="http://example.com/signed-url", + ), + patch("libs.login.check_csrf_token", return_value=None), + ): + mock_file_service.return_value.upload_file.return_value = uploaded_file + + with auth_ctx(): + resp = client.post( + "/console/api/remote-files/upload", + json={"url": url}, + ) + + assert resp.status_code == 201 + p_head.assert_called_once() + # GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds + p_get.assert_called_once() + mock_file_service.return_value.upload_file.assert_called_once() + + data = resp.get_json() + assert data["id"] == "uploaded-file-id" + assert data["name"] == "file.pdf" + assert data["size"] == 1024 + assert data["extension"] == "pdf" + assert data["url"] == "http://example.com/signed-url" + assert data["mime_type"] == "application/pdf" + assert data["created_by"] == "test-account-id" + + @pytest.mark.parametrize( + ("size_ok", "raises", "expected_status", "expected_msg"), + [ + # When size check fails in controller, API returns 413 with message "File size exceeded..." + (False, None, 413, "file size exceeded"), + # When service raises unsupported type, controller maps to 415 with message "File type not allowed." + (True, "unsupported", 415, "file type not allowed"), + ], + ) + def test_upload_remote_file_errors( + self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg + ): + url = "http://example.com/x.pdf" + head_resp = httpx.Response( + 200, + request=httpx.Request("HEAD", url), + headers={"Content-Type": "application/pdf", "Content-Length": "9"}, + ) + file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf") + + with ( + patch( + "controllers.console.remote_files.current_account_with_tenant", + return_value=(mock_account, "test-tenant-id"), + ), + patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp), + patch( + "controllers.console.remote_files.helpers.guess_file_info_from_response", + return_value=file_info, + ), + patch( + "controllers.console.remote_files.FileService.is_file_size_within_limit", + return_value=size_ok, + ), + patch("controllers.console.remote_files.db", spec=["engine"]), + patch("libs.login.check_csrf_token", return_value=None), + ): + if raises == "unsupported": + from services.errors.file import UnsupportedFileTypeError + + with patch("controllers.console.remote_files.FileService") as mock_file_service: + mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad") + with auth_ctx(): + resp = client.post( + "/console/api/remote-files/upload", + json={"url": url}, + ) + else: + with auth_ctx(): + resp = client.post( + "/console/api/remote-files/upload", + json={"url": url}, + ) + + assert resp.status_code == expected_status + data = resp.get_json() + msg = (data.get("error") or {}).get("message") or data.get("message", "") + assert expected_msg in msg.lower() + + def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx): + """Test upload when fetching of remote file fails.""" + with ( + patch( + "controllers.console.remote_files.current_account_with_tenant", + return_value=(mock_account, "test-tenant-id"), + ), + patch( + "controllers.console.remote_files.ssrf_proxy.head", + side_effect=httpx.RequestError("Connection failed"), + ), + patch("libs.login.check_csrf_token", return_value=None), + ): + with auth_ctx(): + resp = client.post( + "/console/api/remote-files/upload", + json={"url": "http://unreachable.com/file.pdf"}, + ) + + assert resp.status_code == 400 + data = resp.get_json() + msg = (data.get("error") or {}).get("message") or data.get("message", "") + assert "failed to fetch" in msg.lower() diff --git a/api/tests/unit_tests/core/schemas/test_resolver.py b/api/tests/unit_tests/core/schemas/test_resolver.py index eda8bf4343..239ee85346 100644 --- a/api/tests/unit_tests/core/schemas/test_resolver.py +++ b/api/tests/unit_tests/core/schemas/test_resolver.py @@ -496,6 +496,9 @@ class TestSchemaResolverClass: avg_time_no_cache = sum(results1) / len(results1) # Second run (with cache) - run multiple times + # Warm up cache first + resolve_dify_schema_refs(schema) + results2 = [] for _ in range(3): start = time.perf_counter()