refactor(api): Inject db dependency to FileService

This commit is contained in:
QuantumGhost 2025-08-14 13:52:44 +08:00
parent 58dfae60f0
commit 40faa9ce16
9 changed files with 55 additions and 40 deletions

View File

@ -22,6 +22,7 @@ from controllers.console.wraps import (
)
from fields.file_fields import file_fields, upload_config_fields
from libs.login import login_required
from models import db
from services.file_service import FileService
PREVIEW_WORDS_LIMIT = 3000
@ -68,7 +69,7 @@ class FileApi(Resource):
source = None
try:
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
@ -89,7 +90,7 @@ class FilePreviewApi(Resource):
@account_initialization_required
def get(self, file_id):
file_id = str(file_id)
text = FileService.get_file_preview(file_id)
text = FileService(db.engine).get_file_preview(file_id)
return {"content": text}

View File

@ -15,6 +15,7 @@ from controllers.common.errors import (
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
from models import db
from models.account import Account
from services.file_service import FileService
@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource):
try:
user = cast(Account, current_user)
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,

View File

@ -211,7 +211,7 @@ class WebappLogoWorkspaceApi(Resource):
raise UnsupportedFileTypeError()
try:
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,

View File

@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
import services
from controllers.common.errors import UnsupportedFileTypeError
from controllers.files import files_ns
from models import db
from services.account_service import TenantService
from services.file_service import FileService
@ -28,7 +29,7 @@ class ImagePreviewApi(Resource):
return {"content": "Invalid request."}, 400
try:
generator, mimetype = FileService.get_image_preview(
generator, mimetype = FileService(db.engine).get_image_preview(
file_id=file_id,
timestamp=timestamp,
nonce=nonce,
@ -57,7 +58,7 @@ class FilePreviewApi(Resource):
return {"content": "Invalid request."}, 400
try:
generator, upload_file = FileService.get_file_generator_by_file_id(
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
file_id=file_id,
timestamp=args["timestamp"],
nonce=args["nonce"],
@ -108,7 +109,7 @@ class WorkspaceWebappLogoApi(Resource):
raise NotFound("webapp logo is not found")
try:
generator, mimetype = FileService.get_public_image_preview(
generator, mimetype = FileService(db.engine).get_public_image_preview(
webapp_logo_file_id,
)
except services.errors.file.UnsupportedFileTypeError:

View File

@ -13,7 +13,7 @@ from controllers.common.errors import (
from controllers.service_api import service_api_ns
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from fields.file_fields import build_file_model
from models.model import App, EndUser
from models import App, EndUser, db
from services.file_service import FileService
@ -52,7 +52,7 @@ class FileApi(Resource):
raise FilenameNotExistsError
try:
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,

View File

@ -123,7 +123,7 @@ class DocumentAddByTextApi(DatasetApiResource):
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
)
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@ -198,7 +198,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
name = args.get("name")
if text is None or name is None:
raise ValueError("Both text and name must be strings.")
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
@ -298,7 +298,7 @@ class DocumentAddByFileApi(DatasetApiResource):
if not file.filename:
raise FilenameNotExistsError
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,
@ -387,7 +387,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
raise FilenameNotExistsError
try:
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,

View File

@ -12,6 +12,7 @@ from controllers.common.errors import (
from controllers.web import web_ns
from controllers.web.wraps import WebApiResource
from fields.file_fields import build_file_model
from models import db
from services.file_service import FileService
@ -68,7 +69,7 @@ class FileApi(WebApiResource):
source = None
try:
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file.filename,
content=file.read(),
mimetype=file.mimetype,

View File

@ -15,6 +15,7 @@ from controllers.web.wraps import WebApiResource
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
from models import db
from services.file_service import FileService
@ -119,7 +120,7 @@ class RemoteFileUploadApi(WebApiResource):
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
upload_file = FileService.upload_file(
upload_file = FileService(db.engine).upload_file(
filename=file_info.filename,
content=content,
mimetype=file_info.mimetype,

View File

@ -4,6 +4,8 @@ import uuid
from typing import Any, Literal, Union
from flask_login import current_user
from sqlalchemy import Engine
from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound
from configs import dify_config
@ -15,7 +17,6 @@ from constants import (
)
from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
@ -29,8 +30,18 @@ PREVIEW_WORDS_LIMIT = 3000
class FileService:
@staticmethod
_session_maker: sessionmaker
def __init__(self, session_factory: sessionmaker | Engine | None = None):
if isinstance(session_factory, Engine):
self._session_maker = sessionmaker(bind=session_factory)
elif isinstance(session_factory, sessionmaker):
self._session_maker = session_factory
else:
raise AssertionError("must be a sessionmaker or an Engine.")
def upload_file(
self,
*,
filename: str,
content: bytes,
@ -85,14 +96,14 @@ class FileService:
hash=hashlib.sha3_256(content).hexdigest(),
source_url=source_url,
)
db.session.add(upload_file)
db.session.commit()
# The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
# We can directly generate the `source_url` here before committing.
if not upload_file.source_url:
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
db.session.add(upload_file)
db.session.commit()
with self._session_maker(expire_on_commit=False) as session:
session.add(upload_file)
session.commit()
return upload_file
@ -109,8 +120,7 @@ class FileService:
return file_size <= file_size_limit
@staticmethod
def upload_text(text: str, text_name: str) -> UploadFile:
def upload_text(self, text: str, text_name: str) -> UploadFile:
if len(text_name) > 200:
text_name = text_name[:200]
# user uuid as file name
@ -137,14 +147,15 @@ class FileService:
used_at=naive_utc_now(),
)
db.session.add(upload_file)
db.session.commit()
with self._session_maker(expire_on_commit=False) as session:
session.add(upload_file)
session.commit()
return upload_file
@staticmethod
def get_file_preview(file_id: str):
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
def get_file_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
@ -159,15 +170,14 @@ class FileService:
return text
@staticmethod
def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str):
def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
result = file_helpers.verify_image_signature(
upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
)
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -181,13 +191,13 @@ class FileService:
return generator, upload_file.mime_type
@staticmethod
def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str):
def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@ -196,9 +206,9 @@ class FileService:
return generator, upload_file
@staticmethod
def get_public_image_preview(file_id: str):
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
def get_public_image_preview(self, file_id: str):
with self._session_maker(expire_on_commit=False) as session:
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")