mirror of
https://github.com/langgenius/dify.git
synced 2026-04-15 09:57:03 +08:00
refactor(api): Inject db dependency to FileService
This commit is contained in:
parent
58dfae60f0
commit
40faa9ce16
@ -22,6 +22,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from fields.file_fields import file_fields, upload_config_fields
|
||||
from libs.login import login_required
|
||||
from models import db
|
||||
from services.file_service import FileService
|
||||
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
@ -68,7 +69,7 @@ class FileApi(Resource):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
@ -89,7 +90,7 @@ class FilePreviewApi(Resource):
|
||||
@account_initialization_required
|
||||
def get(self, file_id):
|
||||
file_id = str(file_id)
|
||||
text = FileService.get_file_preview(file_id)
|
||||
text = FileService(db.engine).get_file_preview(file_id)
|
||||
return {"content": text}
|
||||
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.common.errors import (
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import file_fields_with_signed_url, remote_file_info_fields
|
||||
from models import db
|
||||
from models.account import Account
|
||||
from services.file_service import FileService
|
||||
|
||||
@ -61,7 +62,7 @@ class RemoteFileUploadApi(Resource):
|
||||
|
||||
try:
|
||||
user = cast(Account, current_user)
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
|
||||
@ -211,7 +211,7 @@ class WebappLogoWorkspaceApi(Resource):
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound
|
||||
import services
|
||||
from controllers.common.errors import UnsupportedFileTypeError
|
||||
from controllers.files import files_ns
|
||||
from models import db
|
||||
from services.account_service import TenantService
|
||||
from services.file_service import FileService
|
||||
|
||||
@ -28,7 +29,7 @@ class ImagePreviewApi(Resource):
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_image_preview(
|
||||
generator, mimetype = FileService(db.engine).get_image_preview(
|
||||
file_id=file_id,
|
||||
timestamp=timestamp,
|
||||
nonce=nonce,
|
||||
@ -57,7 +58,7 @@ class FilePreviewApi(Resource):
|
||||
return {"content": "Invalid request."}, 400
|
||||
|
||||
try:
|
||||
generator, upload_file = FileService.get_file_generator_by_file_id(
|
||||
generator, upload_file = FileService(db.engine).get_file_generator_by_file_id(
|
||||
file_id=file_id,
|
||||
timestamp=args["timestamp"],
|
||||
nonce=args["nonce"],
|
||||
@ -108,7 +109,7 @@ class WorkspaceWebappLogoApi(Resource):
|
||||
raise NotFound("webapp logo is not found")
|
||||
|
||||
try:
|
||||
generator, mimetype = FileService.get_public_image_preview(
|
||||
generator, mimetype = FileService(db.engine).get_public_image_preview(
|
||||
webapp_logo_file_id,
|
||||
)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
|
||||
@ -13,7 +13,7 @@ from controllers.common.errors import (
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
|
||||
from fields.file_fields import build_file_model
|
||||
from models.model import App, EndUser
|
||||
from models import App, EndUser, db
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -52,7 +52,7 @@ class FileApi(Resource):
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -123,7 +123,7 @@ class DocumentAddByTextApi(DatasetApiResource):
|
||||
args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
|
||||
)
|
||||
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
@ -198,7 +198,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
|
||||
name = args.get("name")
|
||||
if text is None or name is None:
|
||||
raise ValueError("Both text and name must be strings.")
|
||||
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
|
||||
upload_file = FileService(db.engine).upload_text(text=str(text), text_name=str(name))
|
||||
data_source = {
|
||||
"type": "upload_file",
|
||||
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
|
||||
@ -298,7 +298,7 @@ class DocumentAddByFileApi(DatasetApiResource):
|
||||
if not file.filename:
|
||||
raise FilenameNotExistsError
|
||||
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
@ -387,7 +387,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
|
||||
raise FilenameNotExistsError
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -12,6 +12,7 @@ from controllers.common.errors import (
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from fields.file_fields import build_file_model
|
||||
from models import db
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -68,7 +69,7 @@ class FileApi(WebApiResource):
|
||||
source = None
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file.filename,
|
||||
content=file.read(),
|
||||
mimetype=file.mimetype,
|
||||
|
||||
@ -15,6 +15,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from fields.file_fields import build_file_with_signed_url_model, build_remote_file_info_model
|
||||
from models import db
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@ -119,7 +120,7 @@ class RemoteFileUploadApi(WebApiResource):
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
upload_file = FileService.upload_file(
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
|
||||
@ -4,6 +4,8 @@ import uuid
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from flask_login import current_user
|
||||
from sqlalchemy import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from configs import dify_config
|
||||
@ -15,7 +17,6 @@ from constants import (
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
@ -29,8 +30,18 @@ PREVIEW_WORDS_LIMIT = 3000
|
||||
|
||||
|
||||
class FileService:
|
||||
@staticmethod
|
||||
_session_maker: sessionmaker
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine | None = None):
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_maker = sessionmaker(bind=session_factory)
|
||||
elif isinstance(session_factory, sessionmaker):
|
||||
self._session_maker = session_factory
|
||||
else:
|
||||
raise AssertionError("must be a sessionmaker or an Engine.")
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
*,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
@ -85,14 +96,14 @@ class FileService:
|
||||
hash=hashlib.sha3_256(content).hexdigest(),
|
||||
source_url=source_url,
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
# The `UploadFile` ID is generated within its constructor, so flushing to retrieve the ID is unnecessary.
|
||||
# We can directly generate the `source_url` here before committing.
|
||||
if not upload_file.source_url:
|
||||
upload_file.source_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id)
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
session.add(upload_file)
|
||||
session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@ -109,8 +120,7 @@ class FileService:
|
||||
|
||||
return file_size <= file_size_limit
|
||||
|
||||
@staticmethod
|
||||
def upload_text(text: str, text_name: str) -> UploadFile:
|
||||
def upload_text(self, text: str, text_name: str) -> UploadFile:
|
||||
if len(text_name) > 200:
|
||||
text_name = text_name[:200]
|
||||
# user uuid as file name
|
||||
@ -137,14 +147,15 @@ class FileService:
|
||||
used_at=naive_utc_now(),
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
session.add(upload_file)
|
||||
session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_preview(file_id: str):
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
def get_file_preview(self, file_id: str):
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
@ -159,15 +170,14 @@ class FileService:
|
||||
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str):
|
||||
def get_image_preview(self, file_id: str, timestamp: str, nonce: str, sign: str):
|
||||
result = file_helpers.verify_image_signature(
|
||||
upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign
|
||||
)
|
||||
if not result:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -181,13 +191,13 @@ class FileService:
|
||||
|
||||
return generator, upload_file.mime_type
|
||||
|
||||
@staticmethod
|
||||
def get_file_generator_by_file_id(file_id: str, timestamp: str, nonce: str, sign: str):
|
||||
def get_file_generator_by_file_id(self, file_id: str, timestamp: str, nonce: str, sign: str):
|
||||
result = file_helpers.verify_file_signature(upload_file_id=file_id, timestamp=timestamp, nonce=nonce, sign=sign)
|
||||
if not result:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
@ -196,9 +206,9 @@ class FileService:
|
||||
|
||||
return generator, upload_file
|
||||
|
||||
@staticmethod
|
||||
def get_public_image_preview(file_id: str):
|
||||
upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
def get_public_image_preview(self, file_id: str):
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found or signature is invalid")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user