mirror of
https://github.com/langgenius/dify.git
synced 2026-06-19 08:31:07 +08:00
chore(api): migrate file factory builders and account commands to use Session(db.engine) (#35236)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
881a9a1a08
commit
e70e4fa41d
@ -2,6 +2,7 @@ import base64
|
|||||||
import secrets
|
import secrets
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from constants.languages import languages
|
from constants.languages import languages
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
@ -43,10 +44,11 @@ def reset_password(email, new_password, password_confirm):
|
|||||||
# encrypt password with salt
|
# encrypt password with salt
|
||||||
password_hashed = hash_password(new_password, salt)
|
password_hashed = hash_password(new_password, salt)
|
||||||
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
base64_password_hashed = base64.b64encode(password_hashed).decode()
|
||||||
account = db.session.merge(account)
|
with Session(db.engine) as session:
|
||||||
account.password = base64_password_hashed
|
account = session.merge(account)
|
||||||
account.password_salt = base64_salt
|
account.password = base64_password_hashed
|
||||||
db.session.commit()
|
account.password_salt = base64_salt
|
||||||
|
session.commit()
|
||||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||||
click.echo(click.style("Password reset successfully.", fg="green"))
|
click.echo(click.style("Password reset successfully.", fg="green"))
|
||||||
|
|
||||||
@ -77,9 +79,10 @@ def reset_email(email, new_email, email_confirm):
|
|||||||
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
|
||||||
return
|
return
|
||||||
|
|
||||||
account = db.session.merge(account)
|
with Session(db.engine) as session:
|
||||||
account.email = normalized_new_email
|
account = session.merge(account)
|
||||||
db.session.commit()
|
account.email = normalized_new_email
|
||||||
|
session.commit()
|
||||||
click.echo(click.style("Email updated successfully.", fg="green"))
|
click.echo(click.style("Email updated successfully.", fg="green"))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -10,8 +10,8 @@ from typing import Any
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from core.app.file_access import FileAccessControllerProtocol
|
from core.app.file_access import FileAccessControllerProtocol
|
||||||
|
from core.db.session_factory import session_factory
|
||||||
from core.workflow.file_reference import build_file_reference
|
from core.workflow.file_reference import build_file_reference
|
||||||
from extensions.ext_database import db
|
|
||||||
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
|
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type
|
||||||
from models import ToolFile, UploadFile
|
from models import ToolFile, UploadFile
|
||||||
|
|
||||||
@ -135,29 +135,30 @@ def _build_from_local_file(
|
|||||||
UploadFile.id == upload_file_id,
|
UploadFile.id == upload_file_id,
|
||||||
UploadFile.tenant_id == tenant_id,
|
UploadFile.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
row = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
|
with session_factory.create_session() as session:
|
||||||
if row is None:
|
row = session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||||
raise ValueError("Invalid upload file")
|
if row is None:
|
||||||
|
raise ValueError("Invalid upload file")
|
||||||
|
|
||||||
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type)
|
||||||
file_type = _resolve_file_type(
|
file_type = _resolve_file_type(
|
||||||
detected_file_type=detected_file_type,
|
detected_file_type=detected_file_type,
|
||||||
specified_type=mapping.get("type", "custom"),
|
specified_type=mapping.get("type", "custom"),
|
||||||
strict_type_validation=strict_type_validation,
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=row.name,
|
filename=row.name,
|
||||||
extension="." + row.extension,
|
extension="." + row.extension,
|
||||||
mime_type=row.mime_type,
|
mime_type=row.mime_type,
|
||||||
type=file_type,
|
type=file_type,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
remote_url=row.source_url,
|
remote_url=row.source_url,
|
||||||
reference=build_file_reference(record_id=str(row.id)),
|
reference=build_file_reference(record_id=str(row.id)),
|
||||||
size=row.size,
|
size=row.size,
|
||||||
storage_key=row.key,
|
storage_key=row.key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_from_remote_url(
|
def _build_from_remote_url(
|
||||||
@ -179,32 +180,33 @@ def _build_from_remote_url(
|
|||||||
UploadFile.id == upload_file_id,
|
UploadFile.id == upload_file_id,
|
||||||
UploadFile.tenant_id == tenant_id,
|
UploadFile.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
|
with session_factory.create_session() as session:
|
||||||
if upload_file is None:
|
upload_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||||
raise ValueError("Invalid upload file")
|
if upload_file is None:
|
||||||
|
raise ValueError("Invalid upload file")
|
||||||
|
|
||||||
detected_file_type = standardize_file_type(
|
detected_file_type = standardize_file_type(
|
||||||
extension="." + upload_file.extension,
|
extension="." + upload_file.extension,
|
||||||
mime_type=upload_file.mime_type,
|
mime_type=upload_file.mime_type,
|
||||||
)
|
)
|
||||||
file_type = _resolve_file_type(
|
file_type = _resolve_file_type(
|
||||||
detected_file_type=detected_file_type,
|
detected_file_type=detected_file_type,
|
||||||
specified_type=mapping.get("type"),
|
specified_type=mapping.get("type"),
|
||||||
strict_type_validation=strict_type_validation,
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=upload_file.name,
|
filename=upload_file.name,
|
||||||
extension="." + upload_file.extension,
|
extension="." + upload_file.extension,
|
||||||
mime_type=upload_file.mime_type,
|
mime_type=upload_file.mime_type,
|
||||||
type=file_type,
|
type=file_type,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)),
|
||||||
reference=build_file_reference(record_id=str(upload_file.id)),
|
reference=build_file_reference(record_id=str(upload_file.id)),
|
||||||
size=upload_file.size,
|
size=upload_file.size,
|
||||||
storage_key=upload_file.key,
|
storage_key=upload_file.key,
|
||||||
)
|
)
|
||||||
|
|
||||||
url = mapping.get("url") or mapping.get("remote_url")
|
url = mapping.get("url") or mapping.get("remote_url")
|
||||||
if not url:
|
if not url:
|
||||||
@ -247,30 +249,31 @@ def _build_from_tool_file(
|
|||||||
ToolFile.id == tool_file_id,
|
ToolFile.id == tool_file_id,
|
||||||
ToolFile.tenant_id == tenant_id,
|
ToolFile.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt))
|
with session_factory.create_session() as session:
|
||||||
if tool_file is None:
|
tool_file = session.scalar(access_controller.apply_tool_file_filters(stmt))
|
||||||
raise ValueError(f"ToolFile {tool_file_id} not found")
|
if tool_file is None:
|
||||||
|
raise ValueError(f"ToolFile {tool_file_id} not found")
|
||||||
|
|
||||||
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
|
||||||
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
|
||||||
file_type = _resolve_file_type(
|
file_type = _resolve_file_type(
|
||||||
detected_file_type=detected_file_type,
|
detected_file_type=detected_file_type,
|
||||||
specified_type=mapping.get("type"),
|
specified_type=mapping.get("type"),
|
||||||
strict_type_validation=strict_type_validation,
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("id"),
|
id=mapping.get("id"),
|
||||||
filename=tool_file.name,
|
filename=tool_file.name,
|
||||||
type=file_type,
|
type=file_type,
|
||||||
transfer_method=transfer_method,
|
transfer_method=transfer_method,
|
||||||
remote_url=tool_file.original_url,
|
remote_url=tool_file.original_url,
|
||||||
reference=build_file_reference(record_id=str(tool_file.id)),
|
reference=build_file_reference(record_id=str(tool_file.id)),
|
||||||
extension=extension,
|
extension=extension,
|
||||||
mime_type=tool_file.mimetype,
|
mime_type=tool_file.mimetype,
|
||||||
size=tool_file.size,
|
size=tool_file.size,
|
||||||
storage_key=tool_file.file_key,
|
storage_key=tool_file.file_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _build_from_datasource_file(
|
def _build_from_datasource_file(
|
||||||
@ -289,31 +292,32 @@ def _build_from_datasource_file(
|
|||||||
UploadFile.id == datasource_file_id,
|
UploadFile.id == datasource_file_id,
|
||||||
UploadFile.tenant_id == tenant_id,
|
UploadFile.tenant_id == tenant_id,
|
||||||
)
|
)
|
||||||
datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt))
|
with session_factory.create_session() as session:
|
||||||
if datasource_file is None:
|
datasource_file = session.scalar(access_controller.apply_upload_file_filters(stmt))
|
||||||
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
if datasource_file is None:
|
||||||
|
raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found")
|
||||||
|
|
||||||
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin"
|
||||||
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type)
|
||||||
file_type = _resolve_file_type(
|
file_type = _resolve_file_type(
|
||||||
detected_file_type=detected_file_type,
|
detected_file_type=detected_file_type,
|
||||||
specified_type=mapping.get("type"),
|
specified_type=mapping.get("type"),
|
||||||
strict_type_validation=strict_type_validation,
|
strict_type_validation=strict_type_validation,
|
||||||
)
|
)
|
||||||
|
|
||||||
return File(
|
return File(
|
||||||
id=mapping.get("datasource_file_id"),
|
id=mapping.get("datasource_file_id"),
|
||||||
filename=datasource_file.name,
|
filename=datasource_file.name,
|
||||||
type=file_type,
|
type=file_type,
|
||||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
remote_url=datasource_file.source_url,
|
remote_url=datasource_file.source_url,
|
||||||
reference=build_file_reference(record_id=str(datasource_file.id)),
|
reference=build_file_reference(record_id=str(datasource_file.id)),
|
||||||
extension=extension,
|
extension=extension,
|
||||||
mime_type=datasource_file.mime_type,
|
mime_type=datasource_file.mime_type,
|
||||||
size=datasource_file.size,
|
size=datasource_file.size,
|
||||||
storage_key=datasource_file.key,
|
storage_key=datasource_file.key,
|
||||||
url=datasource_file.source_url,
|
url=datasource_file.source_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
|
def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool:
|
||||||
|
|||||||
@ -11,6 +11,21 @@ from factories.file_factory.builders import build_from_mapping as _build_from_ma
|
|||||||
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig
|
from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig
|
||||||
from models import ToolFile, UploadFile
|
from models import ToolFile, UploadFile
|
||||||
|
|
||||||
|
|
||||||
|
def _make_session_ctx_mock(scalar_return=None):
|
||||||
|
"""Return a mock usable as the ``session_factory.create_session()`` context manager.
|
||||||
|
|
||||||
|
Patch ``factories.file_factory.builders.session_factory`` and set
|
||||||
|
``mock_sf.create_session.return_value = <this mock>`` to intercept DB calls
|
||||||
|
without requiring a live Flask app or database engine.
|
||||||
|
"""
|
||||||
|
session = MagicMock()
|
||||||
|
session.__enter__.return_value = session
|
||||||
|
session.__exit__.return_value = False
|
||||||
|
session.scalar.return_value = scalar_return
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
# Test Data
|
# Test Data
|
||||||
TEST_TENANT_ID = "test_tenant_id"
|
TEST_TENANT_ID = "test_tenant_id"
|
||||||
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
|
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
|
||||||
@ -49,8 +64,11 @@ def mock_upload_file():
|
|||||||
mock.source_url = TEST_REMOTE_URL
|
mock.source_url = TEST_REMOTE_URL
|
||||||
mock.size = 1024
|
mock.size = 1024
|
||||||
mock.key = "test_key"
|
mock.key = "test_key"
|
||||||
with patch("factories.file_factory.builders.db.session.scalar", return_value=mock, autospec=True) as m:
|
session = _make_session_ctx_mock(scalar_return=mock)
|
||||||
yield m
|
with patch("factories.file_factory.builders.session_factory") as mock_sf:
|
||||||
|
mock_sf.create_session.return_value = session
|
||||||
|
# yield session.scalar so callers can inspect call_args and mutate return_value
|
||||||
|
yield session.scalar
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -63,7 +81,9 @@ def mock_tool_file():
|
|||||||
mock.mimetype = "application/pdf"
|
mock.mimetype = "application/pdf"
|
||||||
mock.original_url = "http://example.com/tool.pdf"
|
mock.original_url = "http://example.com/tool.pdf"
|
||||||
mock.size = 2048
|
mock.size = 2048
|
||||||
with patch("factories.file_factory.builders.db.session.scalar", return_value=mock, autospec=True):
|
session = _make_session_ctx_mock(scalar_return=mock)
|
||||||
|
with patch("factories.file_factory.builders.session_factory") as mock_sf:
|
||||||
|
mock_sf.create_session.return_value = session
|
||||||
yield mock
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
@ -231,7 +251,9 @@ def test_build_from_remote_url_without_strict_validation(mock_http_head):
|
|||||||
|
|
||||||
def test_tool_file_not_found():
|
def test_tool_file_not_found():
|
||||||
"""Test ToolFile not found in database."""
|
"""Test ToolFile not found in database."""
|
||||||
with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True):
|
session = _make_session_ctx_mock(scalar_return=None)
|
||||||
|
with patch("factories.file_factory.builders.session_factory") as mock_sf:
|
||||||
|
mock_sf.create_session.return_value = session
|
||||||
mapping = tool_file_mapping()
|
mapping = tool_file_mapping()
|
||||||
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
@ -239,7 +261,9 @@ def test_tool_file_not_found():
|
|||||||
|
|
||||||
def test_local_file_not_found():
|
def test_local_file_not_found():
|
||||||
"""Test UploadFile not found in database."""
|
"""Test UploadFile not found in database."""
|
||||||
with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True):
|
session = _make_session_ctx_mock(scalar_return=None)
|
||||||
|
with patch("factories.file_factory.builders.session_factory") as mock_sf:
|
||||||
|
mock_sf.create_session.return_value = session
|
||||||
mapping = local_file_mapping()
|
mapping = local_file_mapping()
|
||||||
with pytest.raises(ValueError, match="Invalid upload file"):
|
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
@ -311,7 +335,9 @@ def test_tenant_mismatch():
|
|||||||
mock_file.key = "test_key"
|
mock_file.key = "test_key"
|
||||||
|
|
||||||
# Mock the database query to return None (no file found for this tenant)
|
# Mock the database query to return None (no file found for this tenant)
|
||||||
with patch("factories.file_factory.builders.db.session.scalar", return_value=None, autospec=True):
|
session = _make_session_ctx_mock(scalar_return=None)
|
||||||
|
with patch("factories.file_factory.builders.session_factory") as mock_sf:
|
||||||
|
mock_sf.create_session.return_value = session
|
||||||
mapping = local_file_mapping()
|
mapping = local_file_mapping()
|
||||||
with pytest.raises(ValueError, match="Invalid upload file"):
|
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
@ -350,11 +376,13 @@ def test_build_from_mapping_scopes_tool_file_to_end_user():
|
|||||||
invoke_from=InvokeFrom.WEB_APP,
|
invoke_from=InvokeFrom.WEB_APP,
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("factories.file_factory.builders.db.session.scalar", return_value=tool_file, autospec=True) as scalar:
|
session = _make_session_ctx_mock(scalar_return=tool_file)
|
||||||
|
with patch("factories.file_factory.builders.session_factory") as mock_sf:
|
||||||
|
mock_sf.create_session.return_value = session
|
||||||
with bind_file_access_scope(scope):
|
with bind_file_access_scope(scope):
|
||||||
build_from_mapping(mapping=tool_file_mapping(), tenant_id=TEST_TENANT_ID)
|
build_from_mapping(mapping=tool_file_mapping(), tenant_id=TEST_TENANT_ID)
|
||||||
|
|
||||||
stmt = scalar.call_args.args[0]
|
stmt = session.scalar.call_args.args[0]
|
||||||
whereclause = str(stmt.whereclause)
|
whereclause = str(stmt.whereclause)
|
||||||
assert "tool_files.user_id" in whereclause
|
assert "tool_files.user_id" in whereclause
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user