From c6b3e525d11a1b793f333547b6a97b25226f786e Mon Sep 17 00:00:00 2001 From: Shahil kadia Date: Mon, 15 Jun 2026 06:49:16 +0530 Subject: [PATCH] refactor: accept db.session explicitly in RecommendedAppService (#37417) Co-authored-by: Shahil Kadia --- .../console/explore/recommended_app.py | 5 ++-- api/controllers/console/explore/trial.py | 10 +++---- api/services/recommended_app_service.py | 20 +++++++------- .../services/test_recommended_app_service.py | 27 ++++++++++--------- .../console/explore/test_recommended_app.py | 10 +++---- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index c559dc7375..1b53226440 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -9,6 +9,7 @@ from constants.languages import languages from controllers.common.schema import query_params_from_model, register_response_schema_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, with_current_user +from extensions.ext_database import db from fields.base import ResponseModel from libs.helper import build_icon_url from libs.login import login_required @@ -98,7 +99,7 @@ class RecommendedAppListApi(Resource): language_prefix = languages[0] return RecommendedAppListResponse.model_validate( - RecommendedAppService.get_recommended_apps_and_categories(language_prefix), + RecommendedAppService.get_recommended_apps_and_categories(db.session, language_prefix), from_attributes=True, ).model_dump(mode="json") @@ -109,4 +110,4 @@ class RecommendedAppApi(Resource): @login_required @account_initialization_required def get(self, app_id: UUID): - return RecommendedAppService.get_recommend_app_detail(str(app_id)) + return RecommendedAppService.get_recommend_app_detail(db.session, str(app_id)) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index 094415f46d..ad98dd303f 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -223,7 +223,7 @@ class TrialAppWorkflowRunApi(TrialAppResource): response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) - RecommendedAppService.add_trial_app_record(app_id, user_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return helper.compact_generate_response(response) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -296,7 +296,7 @@ class TrialChatApi(TrialAppResource): response = AppGenerateService.generate( app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) - RecommendedAppService.add_trial_app_record(app_id, user_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") @@ -373,7 +373,7 @@ class TrialChatAudioApi(TrialAppResource): user_id = current_user.id response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) - RecommendedAppService.add_trial_app_record(app_id, user_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return response except services.errors.app_model_config.AppModelConfigBrokenError: logger.exception("App model config broken.") @@ -420,7 +420,7 @@ class TrialChatTextApi(TrialAppResource): user_id = current_user.id response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id) - RecommendedAppService.add_trial_app_record(app_id, user_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return response except services.errors.app_model_config.AppModelConfigBrokenError: logger.exception("App model config broken.") @@ -473,7 +473,7 @@ class TrialCompletionApi(TrialAppResource): app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) - RecommendedAppService.add_trial_app_record(app_id, user_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, user_id) return helper.compact_generate_response(response) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") diff --git a/api/services/recommended_app_service.py b/api/services/recommended_app_service.py index 4e189e6e7c..00eb5ee2d1 100644 --- a/api/services/recommended_app_service.py +++ b/api/services/recommended_app_service.py @@ -1,9 +1,9 @@ from typing import Any from sqlalchemy import select +from sqlalchemy.orm import scoped_session from configs import dify_config -from extensions.ext_database import db from models.model import AccountTrialAppRecord, TrialApp from services.feature_service import FeatureService from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory @@ -11,7 +11,7 @@ from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFa class RecommendedAppService: @classmethod - def get_recommended_apps_and_categories(cls, language: str): + def get_recommended_apps_and_categories(cls, session: scoped_session, language: str): """ Get recommended apps and categories. :param language: language @@ -31,7 +31,7 @@ class RecommendedAppService: apps = result["recommended_apps"] for app in apps: app_id = app["app_id"] - trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) + trial_app_model = session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) if trial_app_model: app["can_trial"] = True else: @@ -39,7 +39,7 @@ class RecommendedAppService: return result @classmethod - def get_recommend_app_detail(cls, app_id: str) -> dict[str, Any] | None: + def get_recommend_app_detail(cls, session: scoped_session, app_id: str) -> dict[str, Any] | None: """ Get recommend app detail. :param app_id: app id @@ -52,7 +52,7 @@ class RecommendedAppService: return None if FeatureService.get_system_features().enable_trial_app: app_id = result["id"] - trial_app_model = db.session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) + trial_app_model = session.scalar(select(TrialApp).where(TrialApp.app_id == app_id).limit(1)) if trial_app_model: result["can_trial"] = True else: @@ -60,20 +60,20 @@ class RecommendedAppService: return result @classmethod - def add_trial_app_record(cls, app_id: str, account_id: str): + def add_trial_app_record(cls, session: scoped_session, app_id: str, account_id: str): """ Add trial app record. :param app_id: app id :return: """ - account_trial_app_record = db.session.scalar( + account_trial_app_record = session.scalar( select(AccountTrialAppRecord) .where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id) .limit(1) ) if account_trial_app_record: account_trial_app_record.count += 1 - db.session.commit() + session.commit() else: - db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id)) - db.session.commit() + session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id)) + session.commit() diff --git a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py index 3c7ea311da..c4c4f0ac1f 100644 --- a/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_recommended_app_service.py @@ -9,6 +9,7 @@ import pytest from sqlalchemy import select from sqlalchemy.orm import Session +from extensions.ext_database import db from models.model import AccountTrialAppRecord, TrialApp from services import recommended_app_service as service_module from services.recommended_app_service import RecommendedAppService @@ -117,7 +118,7 @@ class TestRecommendedAppServiceGetApps: mock_factory = MagicMock(return_value=mock_instance) mock_factory_class.get_recommend_app_factory.return_value = mock_factory - result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US") assert result == expected assert len(result["recommended_apps"]) == 2 @@ -142,7 +143,7 @@ class TestRecommendedAppServiceGetApps: mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance - result = RecommendedAppService.get_recommended_apps_and_categories("zh-CN") + result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "zh-CN") assert result == builtin_response assert result["recommended_apps"][0]["id"] == "builtin-1" @@ -163,7 +164,7 @@ class TestRecommendedAppServiceGetApps: mock_builtin_instance.fetch_recommended_apps_from_builtin.return_value = builtin_response mock_factory_class.get_buildin_recommend_app_retrieval.return_value = mock_builtin_instance - result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US") assert result == builtin_response mock_builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once() @@ -181,7 +182,7 @@ class TestRecommendedAppServiceGetApps: mock_instance.get_recommended_apps_and_categories.return_value = lang_response mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) - result = RecommendedAppService.get_recommended_apps_and_categories(language) + result = RecommendedAppService.get_recommended_apps_and_categories(db.session, language) assert result["recommended_apps"][0]["id"] == f"app-{language}" mock_instance.get_recommended_apps_and_categories.assert_called_with(language) @@ -196,7 +197,7 @@ class TestRecommendedAppServiceGetApps: mock_instance.get_recommended_apps_and_categories.return_value = response mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) - RecommendedAppService.get_recommended_apps_and_categories("en-US") + RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US") mock_factory_class.get_recommend_app_factory.assert_called_with(mode) @@ -236,7 +237,7 @@ class TestRecommendedAppServiceGetDetail: mock_instance.get_recommend_app_detail.return_value = expected mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) - result = RecommendedAppService.get_recommend_app_detail(app_id) + result = RecommendedAppService.get_recommend_app_detail(db.session, app_id) assert result == expected mock_instance.get_recommend_app_detail.assert_called_once_with(app_id) @@ -255,7 +256,7 @@ class TestRecommendedAppServiceGetDetail: mock_instance.get_recommend_app_detail.return_value = detail mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) - result = RecommendedAppService.get_recommend_app_detail("test-app") + result = RecommendedAppService.get_recommend_app_detail(db.session, "test-app") assert result is not None mock_instance.get_recommend_app_detail.assert_called_with("test-app") @@ -275,7 +276,7 @@ class TestRecommendedAppServiceTrialFeatures: MagicMock(return_value=SimpleNamespace(enable_trial_app=False)), ) - result = RecommendedAppService.get_recommended_apps_and_categories("en-US") + result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "en-US") assert result == expected retrieval_instance.get_recommended_apps_and_categories.assert_called_once_with("en-US") @@ -306,7 +307,7 @@ class TestRecommendedAppServiceTrialFeatures: MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), ) - result = RecommendedAppService.get_recommended_apps_and_categories("ja-JP") + result = RecommendedAppService.get_recommended_apps_and_categories(db.session, "ja-JP") builtin_instance.fetch_recommended_apps_from_builtin.assert_called_once_with("en-US") assert result["recommended_apps"][0]["can_trial"] is True @@ -342,7 +343,7 @@ class TestRecommendedAppServiceTrialFeatures: MagicMock(return_value=SimpleNamespace(enable_trial_app=True)), ) - result = RecommendedAppService.get_recommend_app_detail(app_id) + result = RecommendedAppService.get_recommend_app_detail(db.session, app_id) assert result is not None detail_result = cast(RecommendedAppPayload, result) @@ -363,7 +364,7 @@ class TestRecommendedAppServiceTrialFeatures: mock_instance.get_recommend_app_detail.return_value = None mock_factory_class.get_recommend_app_factory.return_value = MagicMock(return_value=mock_instance) - result = RecommendedAppService.get_recommend_app_detail("nonexistent") + result = RecommendedAppService.get_recommend_app_detail(db.session, "nonexistent") assert result is None mock_instance.get_recommend_app_detail.assert_called_once_with("nonexistent") @@ -376,7 +377,7 @@ class TestRecommendedAppServiceTrialFeatures: db_session_with_containers.add(AccountTrialAppRecord(app_id=app_id, account_id=account_id, count=3)) db_session_with_containers.commit() - RecommendedAppService.add_trial_app_record(app_id, account_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, account_id) db_session_with_containers.expire_all() record = db_session_with_containers.scalar( @@ -391,7 +392,7 @@ class TestRecommendedAppServiceTrialFeatures: app_id = str(uuid.uuid4()) account_id = str(uuid.uuid4()) - RecommendedAppService.add_trial_app_record(app_id, account_id) + RecommendedAppService.add_trial_app_record(db.session, app_id, account_id) db_session_with_containers.expire_all() record = db_session_with_containers.scalar( diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py index 0121d5c424..e0eab9a4d3 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import ANY, patch from flask import Flask @@ -37,7 +37,7 @@ class TestRecommendedAppListApi: ): result = method(api, make_account("fr-FR")) - service_mock.assert_called_once_with("en-US") + service_mock.assert_called_once_with(ANY, "en-US") assert result == result_data def test_get_fallback_to_user_language(self, app: Flask): @@ -56,7 +56,7 @@ class TestRecommendedAppListApi: ): result = method(api, make_account("fr-FR")) - service_mock.assert_called_once_with("fr-FR") + service_mock.assert_called_once_with(ANY, "fr-FR") assert result == result_data def test_get_fallback_to_default_language(self, app: Flask): @@ -75,7 +75,7 @@ class TestRecommendedAppListApi: ): result = method(api, make_account(None)) - service_mock.assert_called_once_with(module.languages[0]) + service_mock.assert_called_once_with(ANY, module.languages[0]) assert result == result_data @@ -96,7 +96,7 @@ class TestRecommendedAppApi: ): result = method(api, "11111111-1111-1111-1111-111111111111") - service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") + service_mock.assert_called_once_with(ANY, "11111111-1111-1111-1111-111111111111") assert result == result_data