refactor: accept db.session explicitly in RecommendedAppService (#37417)

Co-authored-by: Shahil Kadia <shahil@users.noreply.github.com>
This commit is contained in:
Shahil kadia 2026-06-15 06:49:16 +05:30 committed by GitHub
parent a875d76290
commit c6b3e525d1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 37 additions and 35 deletions

View File

@ -9,6 +9,7 @@ from constants.languages import languages
from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, with_current_user
from extensions.ext_database import db
from fields.base import ResponseModel
from libs.helper import build_icon_url
from libs.login import login_required
@ -98,7 +99,7 @@ class RecommendedAppListApi(Resource):
language_prefix = languages[0]
return RecommendedAppListResponse.model_validate(
RecommendedAppService.get_recommended_apps_and_categories(language_prefix),
RecommendedAppService.get_recommended_apps_and_categories(db.session, language_prefix),
from_attributes=True,
).model_dump(mode="json")
@ -109,4 +110,4 @@ class RecommendedAppApi(Resource):
@login_required
@account_initialization_required
def get(self, app_id: UUID):
return RecommendedAppService.get_recommend_app_detail(str(app_id))
return RecommendedAppService.get_recommend_app_detail(db.session, str(app_id))

View File

@ -223,7 +223,7 @@ class TrialAppWorkflowRunApi(TrialAppResource):
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@ -296,7 +296,7 @@ class TrialChatApi(TrialAppResource):
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
@ -373,7 +373,7 @@ class TrialChatAudioApi(TrialAppResource):
user_id = current_user.id
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_id, user_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
@ -420,7 +420,7 @@ class TrialChatTextApi(TrialAppResource):
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_id, user_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
@ -473,7 +473,7 @@ class TrialCompletionApi(TrialAppResource):
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@ -1,9 +1,9 @@
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import scoped_session
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -11,7 +11,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa
class RecommendedAppService:
@classmethod
def get_recommended_apps_and_categories(cls, language: str):
def get_recommended_apps_and_categories(cls, session: scoped_session, language: str):
"""
Get recommended apps and categories.
:param language: language
@ -31,7 +31,7 @@ class RecommendedAppService:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
trial_app_model = session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
if trial_app_model:
app["can_trial"] = True
else:
@ -39,7 +39,7 @@ class RecommendedAppService:
return result
@classmethod
def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None:
def get_recommend_app_detail(cls, session: scoped_session, app_id: str) -> dict[str, Any] | None:
"""
Get recommend app detail.
:param app_id: app id
@ -52,7 +52,7 @@ class RecommendedAppService:
return None
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
trial_app_model = session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1))
if trial_app_model:
result["can_trial"] = True
else:
@ -60,20 +60,20 @@ class RecommendedAppService:
return result
@classmethod
def add_trial_app_record(cls, app_id: str, account_id: str):
def add_trial_app_record(cls, session: scoped_session, app_id: str, account_id: str):
"""
Add trial app record.
:param app_id: app id
:return:
"""
account_trial_app_record = db.session.scalar(
account_trial_app_record = session.scalar(
select(AccountTrialAppRecord)
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
.limit(1)
)
if account_trial_app_record:
account_trial_app_record.count += 1
db.session.commit()
session.commit()
else:
db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
db.session.commit()
session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
session.commit()

View File

@ -9,6 +9,7 @@ import pytest
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services import recommended_app_service as service_module
from services.recommended_app_service import RecommendedAppService
@ -117,7 +118,7 @@ class TestRecommendedAppServiceGetApps:
mock_factory = MagicMock(return_value=mock_instance)
mock_factory_class.get_recommend_app_factory.return_value = mock_factory
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
assert result == expected
assert len(result["recommended_apps"]) == 2
@ -142,7 +143,7 @@ class TestRecommendedAppServiceGetApps:
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN")
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "zh-CN")
assert result == builtin_response
assert result["recommended_apps"][0]["id"] == "builtin-1"
@ -163,7 +164,7 @@ class TestRecommendedAppServiceGetApps:
mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response
mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
assert result == builtin_response
mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once()
@ -181,7 +182,7 @@ class TestRecommendedAppServiceGetApps:
mock_instance.get_recommended_apps_and_categories.return_value = lang_response
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = RecommendedAppService.get_recommended_apps_and_categories(language)
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, language)
assert result["recommended_apps"][0]["id"] == f"app-{language}"
mock_instance.get_recommended_apps_and_categories.assert_called_with(language)
@ -196,7 +197,7 @@ class TestRecommendedAppServiceGetApps:
mock_instance.get_recommended_apps_and_categories.return_value = response
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
RecommendedAppService.get_recommended_apps_and_categories("en-US")
RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
mock_factory_class.get_recommend_app_factory.assert_called_with(mode)
@ -236,7 +237,7 @@ class TestRecommendedAppServiceGetDetail:
mock_instance.get_recommend_app_detail.return_value = expected
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = RecommendedAppService.get_recommend_app_detail(db.session, app_id)
assert result == expected
mock_instance.get_recommend_app_detail.assert_called_once_with(app_id)
@ -255,7 +256,7 @@ class TestRecommendedAppServiceGetDetail:
mock_instance.get_recommend_app_detail.return_value = detail
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = RecommendedAppService.get_recommend_app_detail("test-app")
result = RecommendedAppService.get_recommend_app_detail(db.session, "test-app")
assert result is not None
mock_instance.get_recommend_app_detail.assert_called_with("test-app")
@ -275,7 +276,7 @@ class TestRecommendedAppServiceTrialFeatures:
MagicMock(return_value=SimpleNamespace(enable_trial_app=False)),
)
result = RecommendedAppService.get_recommended_apps_and_categories("en-US")
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US")
assert result == expected
retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US")
@ -306,7 +307,7 @@ class TestRecommendedAppServiceTrialFeatures:
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP")
result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "ja-JP")
builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US")
assert result["recommended_apps"][0]["can_trial"] is True
@ -342,7 +343,7 @@ class TestRecommendedAppServiceTrialFeatures:
MagicMock(return_value=SimpleNamespace(enable_trial_app=True)),
)
result = RecommendedAppService.get_recommend_app_detail(app_id)
result = RecommendedAppService.get_recommend_app_detail(db.session, app_id)
assert result is not None
detail_result = cast(RecommendedAppPayload, result)
@ -363,7 +364,7 @@ class TestRecommendedAppServiceTrialFeatures:
mock_instance.get_recommend_app_detail.return_value = None
mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance)
result = RecommendedAppService.get_recommend_app_detail("nonexistent")
result = RecommendedAppService.get_recommend_app_detail(db.session, "nonexistent")
assert result is None
mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent")
@ -376,7 +377,7 @@ class TestRecommendedAppServiceTrialFeatures:
db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3))
db_session_with_containers.commit()
RecommendedAppService.add_trial_app_record(app_id, account_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, account_id)
db_session_with_containers.expire_all()
record = db_session_with_containers.scalar(
@ -391,7 +392,7 @@ class TestRecommendedAppServiceTrialFeatures:
app_id = str(uuid.uuid4())
account_id = str(uuid.uuid4())
RecommendedAppService.add_trial_app_record(app_id, account_id)
RecommendedAppService.add_trial_app_record(db.session, app_id, account_id)
db_session_with_containers.expire_all()
record = db_session_with_containers.scalar(

View File

@ -1,4 +1,4 @@
from unittest.mock import patch
from unittest.mock import ANY, patch
from flask import Flask
@ -37,7 +37,7 @@ class TestRecommendedAppListApi:
):
result = method(api, make_account("fr-FR"))
service_mock.assert_called_once_with("en-US")
service_mock.assert_called_once_with(ANY, "en-US")
assert result == result_data
def test_get_fallback_to_user_language(self, app: Flask):
@ -56,7 +56,7 @@ class TestRecommendedAppListApi:
):
result = method(api, make_account("fr-FR"))
service_mock.assert_called_once_with("fr-FR")
service_mock.assert_called_once_with(ANY, "fr-FR")
assert result == result_data
def test_get_fallback_to_default_language(self, app: Flask):
@ -75,7 +75,7 @@ class TestRecommendedAppListApi:
):
result = method(api, make_account(None))
service_mock.assert_called_once_with(module.languages[0])
service_mock.assert_called_once_with(ANY, module.languages[0])
assert result == result_data
@ -96,7 +96,7 @@ class TestRecommendedAppApi:
):
result = method(api, "11111111-1111-1111-1111-111111111111")
service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111")
service_mock.assert_called_once_with(ANY, "11111111-1111-1111-1111-111111111111")
assert result == result_data