mirror of
https://github.com/langgenius/dify.git
synced 2026-06-16 14:01:10 +08:00
refactor: accept db.session explicitly in RecommendedAppService (#37417)
Co-authored-by: Shahil Kadia <shahil@users.noreply.github.com>
This commit is contained in:
parent
a875d76290
commit
c6b3e525d1
@ -9,6 +9,7 @@ from constants.languages import languages
|
||||
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, with_current_user
|
||||
from extensions.ext_database import db
|
||||
from fields.base import ResponseModel
|
||||
from libs.helper import build_icon_url
|
||||
from libs.login import login_required
|
||||
@ -98,7 +99,7 @@ class RecommendedAppListApi(Resource):
|
||||
language_prefix = languages[0]
|
||||
|
||||
return RecommendedAppListResponse.model_validate(
|
||||
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
|
||||
RecommendedAppService.get_recommended_apps_and_categories(db.session, language_prefix),
|
||||
from_attributes=True,
|
||||
).model_dump(mode="json")
|
||||
|
||||
@ -109,4 +110,4 @@ class RecommendedAppApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, app_id: UUID):
|
||||
return RecommendedAppService.get_recommend_app_detail(str(app_id))
|
||||
return RecommendedAppService.get_recommend_app_detail(db.session, str(app_id))
|
||||
|
||||
@ -223,7 +223,7 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@ -296,7 +296,7 @@ class TrialChatApi(TrialAppResource):
|
||||
response = AppGenerateService.generate(
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
|
||||
)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
@ -373,7 +373,7 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
@ -420,7 +420,7 @@ class TrialChatTextApi(TrialAppResource):
|
||||
user_id = current_user.id
|
||||
|
||||
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return response
|
||||
except services.errors.app_model_config.AppModelConfigBrokenError:
|
||||
logger.exception("App model config broken.")
|
||||
@ -473,7 +473,7 @@ class TrialCompletionApi(TrialAppResource):
|
||||
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
|
||||
)
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, user_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
|
||||
return helper.compact_generate_response(response)
|
||||
except services.errors.conversation.ConversationNotExistsError:
|
||||
raise NotFound("Conversation Not Exists.")
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import scoped_session
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services.feature_service import FeatureService
|
||||
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
|
||||
@ -11,7 +11,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa
|
||||
|
||||
class RecommendedAppService:
|
||||
@classmethod
|
||||
def get_recommended_apps_and_categories(cls, language: str):
|
||||
def get_recommended_apps_and_categories(cls, session: scoped_session, language: str):
|
||||
"""
|
||||
Get recommended apps and categories.
|
||||
:param language: language
|
||||
@ -31,7 +31,7 @@ class RecommendedAppService:
|
||||
apps = result["recommended_apps"]
|
||||
for app in apps:
|
||||
app_id = app["app_id"]
|
||||
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
|
||||
trial_app_model = session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
|
||||
if trial_app_model:
|
||||
app["can_trial"] = True
|
||||
else:
|
||||
@ -39,7 +39,7 @@ class RecommendedAppService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None:
|
||||
def get_recommend_app_detail(cls, session: scoped_session, app_id: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Get recommend app detail.
|
||||
:param app_id: app id
|
||||
@ -52,7 +52,7 @@ class RecommendedAppService:
|
||||
return None
|
||||
if FeatureService.get_system_features().enable_trial_app:
|
||||
app_id = result["id"]
|
||||
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
|
||||
trial_app_model = session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
|
||||
if trial_app_model:
|
||||
result["can_trial"] = True
|
||||
else:
|
||||
@ -60,20 +60,20 @@ class RecommendedAppService:
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def add_trial_app_record(cls, app_id: str, account_id: str):
|
||||
def add_trial_app_record(cls, session: scoped_session, app_id: str, account_id: str):
|
||||
"""
|
||||
Add trial app record.
|
||||
:param app_id: app id
|
||||
:return:
|
||||
"""
|
||||
account_trial_app_record = db.session.scalar(
|
||||
account_trial_app_record = session.scalar(
|
||||
select(AccountTrialAppRecord)
|
||||
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
|
||||
.limit(1)
|
||||
)
|
||||
if account_trial_app_record:
|
||||
account_trial_app_record.count += 1
|
||||
db.session.commit()
|
||||
session.commit()
|
||||
else:
|
||||
db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
|
||||
db.session.commit()
|
||||
session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
|
||||
session.commit()
|
||||
|
||||
@ -9,6 +9,7 @@ import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import AccountTrialAppRecord, TrialApp
|
||||
from services import recommended_app_service as service_module
|
||||
from services.recommended_app_service import RecommendedAppService
|
||||
@ -117,7 +118,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_factory = MagicMock(return_value=mock_instance)
|
||||
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
|
||||
|
||||
assert result == expected
|
||||
assert len(result["recommended_apps"]) == 2
|
||||
@ -142,7 +143,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "zh-CN")
|
||||
|
||||
assert result == builtin_response
|
||||
assert result["recommended_apps"][0]["id"] == "builtin-1"
|
||||
@ -163,7 +164,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
|
||||
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
|
||||
|
||||
assert result == builtin_response
|
||||
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
|
||||
@ -181,7 +182,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(language)
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, language)
|
||||
|
||||
assert result["recommended_apps"][0]["id"] == f"app-{language}"
|
||||
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
|
||||
@ -196,7 +197,7 @@ class TestRecommendedAppServiceGetApps:
|
||||
mock_instance.get_recommended_apps_and_categories.return_value = response
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
|
||||
|
||||
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
|
||||
|
||||
@ -236,7 +237,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_instance.get_recommend_app_detail.return_value = expected
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = RecommendedAppService.get_recommend_app_detail(db.session, app_id)
|
||||
|
||||
assert result == expected
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
|
||||
@ -255,7 +256,7 @@ class TestRecommendedAppServiceGetDetail:
|
||||
mock_instance.get_recommend_app_detail.return_value = detail
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = RecommendedAppService.get_recommend_app_detail("test-app")
|
||||
result = RecommendedAppService.get_recommend_app_detail(db.session, "test-app")
|
||||
|
||||
assert result is not None
|
||||
mock_instance.get_recommend_app_detail.assert_called_with("test-app")
|
||||
@ -275,7 +276,7 @@ class TestRecommendedAppServiceTrialFeatures:
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
|
||||
)
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
|
||||
|
||||
assert result == expected
|
||||
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
|
||||
@ -306,7 +307,7 @@ class TestRecommendedAppServiceTrialFeatures:
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
|
||||
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "ja-JP")
|
||||
|
||||
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
|
||||
assert result["recommended_apps"][0]["can_trial"] is True
|
||||
@ -342,7 +343,7 @@ class TestRecommendedAppServiceTrialFeatures:
|
||||
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
|
||||
)
|
||||
|
||||
result = RecommendedAppService.get_recommend_app_detail(app_id)
|
||||
result = RecommendedAppService.get_recommend_app_detail(db.session, app_id)
|
||||
assert result is not None
|
||||
detail_result = cast(RecommendedAppPayload, result)
|
||||
|
||||
@ -363,7 +364,7 @@ class TestRecommendedAppServiceTrialFeatures:
|
||||
mock_instance.get_recommend_app_detail.return_value = None
|
||||
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
|
||||
|
||||
result = RecommendedAppService.get_recommend_app_detail("nonexistent")
|
||||
result = RecommendedAppService.get_recommend_app_detail(db.session, "nonexistent")
|
||||
|
||||
assert result is None
|
||||
mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent")
|
||||
@ -376,7 +377,7 @@ class TestRecommendedAppServiceTrialFeatures:
|
||||
db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3))
|
||||
db_session_with_containers.commit()
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, account_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, account_id)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
record = db_session_with_containers.scalar(
|
||||
@ -391,7 +392,7 @@ class TestRecommendedAppServiceTrialFeatures:
|
||||
app_id = str(uuid.uuid4())
|
||||
account_id = str(uuid.uuid4())
|
||||
|
||||
RecommendedAppService.add_trial_app_record(app_id, account_id)
|
||||
RecommendedAppService.add_trial_app_record(db.session, app_id, account_id)
|
||||
|
||||
db_session_with_containers.expire_all()
|
||||
record = db_session_with_containers.scalar(
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import ANY, patch
|
||||
|
||||
from flask import Flask
|
||||
|
||||
@ -37,7 +37,7 @@ class TestRecommendedAppListApi:
|
||||
):
|
||||
result = method(api, make_account("fr-FR"))
|
||||
|
||||
service_mock.assert_called_once_with("en-US")
|
||||
service_mock.assert_called_once_with(ANY, "en-US")
|
||||
assert result == result_data
|
||||
|
||||
def test_get_fallback_to_user_language(self, app: Flask):
|
||||
@ -56,7 +56,7 @@ class TestRecommendedAppListApi:
|
||||
):
|
||||
result = method(api, make_account("fr-FR"))
|
||||
|
||||
service_mock.assert_called_once_with("fr-FR")
|
||||
service_mock.assert_called_once_with(ANY, "fr-FR")
|
||||
assert result == result_data
|
||||
|
||||
def test_get_fallback_to_default_language(self, app: Flask):
|
||||
@ -75,7 +75,7 @@ class TestRecommendedAppListApi:
|
||||
):
|
||||
result = method(api, make_account(None))
|
||||
|
||||
service_mock.assert_called_once_with(module.languages[0])
|
||||
service_mock.assert_called_once_with(ANY, module.languages[0])
|
||||
assert result == result_data
|
||||
|
||||
|
||||
@ -96,7 +96,7 @@ class TestRecommendedAppApi:
|
||||
):
|
||||
result = method(api, "11111111-1111-1111-1111-111111111111")
|
||||
|
||||
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
|
||||
service_mock.assert_called_once_with(ANY, "11111111-1111-1111-1111-111111111111")
|
||||
assert result == result_data
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user