From 40faa9ce16287cc11485e34a119eb50d66348419 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 14 Aug 2025 13:52:44 +0800 Subject: [PATCH] refactor(api): Inject db dependency to FileService --- api/controllers/console/files.py | 5 +- api/controllers/console/remote_files.py | 3 +- .../console/workspace/workspace.py | 2 +- api/controllers/files/image_preview.py | 7 ++- api/controllers/service_api/app/file.py | 4 +- .../service_api/dataset/document.py | 8 +-- api/controllers/web/files.py | 3 +- api/controllers/web/remote_files.py | 3 +- api/services/file_service.py | 60 +++++++++++-------- 9 files changed, 55 insertions(+), 40 deletions(-) diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py index 101a49a32e..74543e3a10 100644 --- a/api/controllers/console/files.py +++ b/api/controllers/console/files.py @@ -22,6 +22,7 @@ from controllers.console.wraps import ( ) from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required +from models import db from services.file_service import FileService PREVIEW_WORDS_LIMIT = 3000 @@ -68,7 +69,7 @@ class FileApi(Resource): source = None try: - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), mimetype=file.mimetype, @@ -89,7 +90,7 @@ class FilePreviewApi(Resource): @account_initialization_required def get(self, file_id): file_id = str(file_id) - text = FileService.get_file_preview(file_id) + text = FileService(db.engine).get_file_preview(file_id) return {"content": text} diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 73014cfc97..e0a31bccbb 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -15,6 +15,7 @@ from controllers.common.errors import ( from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields +from models import db from models.account import Account from services.file_service import FileService @@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource): try: user = cast(Account, current_user) - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, mimetype=file_info.mimetype, diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index e7a3aca66c..3c8299b2a1 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -211,7 +211,7 @@ class WebappLogoWorkspaceApi(Resource): raise UnsupportedFileTypeError() try: - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), mimetype=file.mimetype, diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 48baac6556..643441a310 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound import services from controllers.common.errors import UnsupportedFileTypeError from controllers.files import files_ns +from models import db from services.account_service import TenantService from services.file_service import FileService @@ -28,7 +29,7 @@ class ImagePreviewApi(Resource): return {"content": "Invalid request."}, 400 try: - generator, mimetype = FileService.get_image_preview( + generator, mimetype = FileService(db.engine).get_image_preview( file_id=file_id, timestamp=timestamp, nonce=nonce, @@ -57,7 +58,7 @@ class FilePreviewApi(Resource): return {"content": "Invalid request."}, 400 try: - generator, upload_file = FileService.get_file_generator_by_file_id( + generator, upload_file = FileService(db.engine).get_file_generator_by_file_id( file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], @@ -108,7 +109,7 @@ class WorkspaceWebappLogoApi(Resource): raise NotFound("webapp logo is not found") try: - generator, mimetype = FileService.get_public_image_preview( + generator, mimetype = FileService(db.engine).get_public_image_preview( webapp_logo_file_id, ) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 05f27545b3..b8da10bec3 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -13,7 +13,7 @@ from controllers.common.errors import ( from controllers.service_api import service_api_ns from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.file_fields import build_file_model -from models.model import App, EndUser +from models import App, EndUser, db from services.file_service import FileService @@ -52,7 +52,7 @@ class FileApi(Resource): raise FilenameNotExistsError try: - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), mimetype=file.mimetype, diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 43232229c8..7494f71c23 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -123,7 +123,7 @@ class DocumentAddByTextApi(DatasetApiResource): args.get("retrieval_model").get("reranking_model").get("reranking_model_name"), ) - upload_file = FileService.upload_text(text=str(text), text_name=str(name)) + upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -198,7 +198,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): name = args.get("name") if text is None or name is None: raise ValueError("Both text and name must be strings.") - upload_file = FileService.upload_text(text=str(text), text_name=str(name)) + upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name)) data_source = { "type": "upload_file", "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, @@ -298,7 +298,7 @@ class DocumentAddByFileApi(DatasetApiResource): if not file.filename: raise FilenameNotExistsError - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), mimetype=file.mimetype, @@ -387,7 +387,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): raise FilenameNotExistsError try: - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), mimetype=file.mimetype, diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py index 7508874fae..49b86a355d 100644 --- a/api/controllers/web/files.py +++ b/api/controllers/web/files.py @@ -12,6 +12,7 @@ from controllers.common.errors import ( from controllers.web import web_ns from controllers.web.wraps import WebApiResource from fields.file_fields import build_file_model +from models import db from services.file_service import FileService @@ -68,7 +69,7 @@ class FileApi(WebApiResource): source = None try: - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file.filename, content=file.read(), mimetype=file.mimetype, diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index ab20c7667c..2d84e52c5e 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -15,6 +15,7 @@ from controllers.web.wraps import WebApiResource from core.file import helpers as file_helpers from core.helper import ssrf_proxy from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model +from models import db from services.file_service import FileService @@ -119,7 +120,7 @@ class RemoteFileUploadApi(WebApiResource): content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content try: - upload_file = FileService.upload_file( + upload_file = FileService(db.engine).upload_file( filename=file_info.filename, content=content, mimetype=file_info.mimetype, diff --git a/api/services/file_service.py b/api/services/file_service.py index 4c0a0f451c..dbab4cf827 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -4,6 +4,8 @@ import uuid from typing import Any, Literal, Union from flask_login import current_user +from sqlalchemy import Engine +from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound from configs import dify_config @@ -15,7 +17,6 @@ from constants import ( ) from core.file import helpers as file_helpers from core.rag.extractor.extract_processor import ExtractProcessor -from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id @@ -29,8 +30,18 @@ PREVIEW_WORDS_LIMIT = 3000 class FileService: - @staticmethod + _session_maker: sessionmaker + + def __init__(self, session_factory: sessionmaker | Engine | None = None): + if isinstance(session_factory, Engine): + self._session_maker = sessionmaker(bind=session_factory) + elif isinstance(session_factory, sessionmaker): + self._session_maker = session_factory + else: + raise AssertionError("must be a sessionmaker or an Engine.") + def upload_file( + self, *, filename: str, content: bytes, @@ -85,14 +96,14 @@ class FileService: hash=hashlib.sha3_256(content).hexdigest(), source_url=source_url, ) - - db.session.add(upload_file) - db.session.commit() - + # The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary. + # We can directly generate the `source_url` here before committing. if not upload_file.source_url: upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id) - db.session.add(upload_file) - db.session.commit() + + with self._session_maker(expire_on_commit=False) as session: + session.add(upload_file) + session.commit() return upload_file @@ -109,8 +120,7 @@ class FileService: return file_size <= file_size_limit - @staticmethod - def upload_text(text: str, text_name: str) -> UploadFile: + def upload_text(self, text: str, text_name: str) -> UploadFile: if len(text_name) > 200: text_name = text_name[:200] # user uuid as file name @@ -137,14 +147,15 @@ class FileService: used_at=naive_utc_now(), ) - db.session.add(upload_file) - db.session.commit() + with self._session_maker(expire_on_commit=False) as session: + session.add(upload_file) + session.commit() return upload_file - @staticmethod - def get_file_preview(file_id: str): - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + def get_file_preview(self, file_id: str): + with self._session_maker(expire_on_commit=False) as session: + upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found") @@ -159,15 +170,14 @@ class FileService: return text - @staticmethod - def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str): + def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str): result = file_helpers.verify_image_signature( upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign ) if not result: raise NotFound("File not found or signature is invalid") - - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + with self._session_maker(expire_on_commit=False) as session: + upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -181,13 +191,13 @@ class FileService: return generator, upload_file.mime_type - @staticmethod - def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str): + def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str): result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign) if not result: raise NotFound("File not found or signature is invalid") - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + with self._session_maker(expire_on_commit=False) as session: + upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid") @@ -196,9 +206,9 @@ class FileService: return generator, upload_file - @staticmethod - def get_public_image_preview(file_id: str): - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + def get_public_image_preview(self, file_id: str): + with self._session_maker(expire_on_commit=False) as session: + upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first() if not upload_file: raise NotFound("File not found or signature is invalid")