chore: use dify_config.BILLING_ENABLED (#36619)

This commit is contained in:
非法操作 2026-05-25 17:41:01 +08:00 committed by GitHub
parent 3a467d1d63
commit fbfb4b3a00
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 93 additions and 22 deletions

View File

@ -82,9 +82,7 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if not features.billing.enabled:
if not dify_config.BILLING_ENABLED:
abort(403, "Billing feature is not enabled.")
return view(*args, **kwargs)
@ -198,15 +196,11 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
with contextlib.suppress(Exception):
_, current_tenant_id = current_account_with_tenant()
features = FeatureService.get_features(current_tenant_id)
if features.billing.enabled:
utm_info = request.cookies.get("utm_info")
if utm_info:
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
utm_info = request.cookies.get("utm_info")
if dify_config.BILLING_ENABLED and utm_info:
_, current_tenant_id = current_account_with_tenant()
utm_info_dict: UtmInfo = json.loads(utm_info)
OperationService.record_utm(current_tenant_id, utm_info_dict)
return view(*args, **kwargs)

View File

@ -42,7 +42,6 @@ from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetPro
from models.dataset import Document as DatasetDocument
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
from models.model import UploadFile
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
@ -282,8 +281,7 @@ class IndexingRunner:
Estimate the indexing for the document.
"""
# check document limit
features = FeatureService.get_features(tenant_id)
if features.billing.enabled:
if dify_config.BILLING_ENABLED:
count = len(extract_settings)
batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
if count > batch_upload_limit:

View File

@ -8,8 +8,10 @@ from controllers.console.error import NotInitValidateError, NotSetupError, Unaut
from controllers.console.workspace.error import AccountNotInitializedError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
cloud_utm_record,
enterprise_license_required,
only_edition_cloud,
only_edition_enterprise,
@ -147,6 +149,42 @@ class TestEditionChecks:
assert result == "self_hosted_success"
class TestBillingEnabled:
"""Test billing enabled decorator."""
def test_should_allow_when_billing_config_enabled(self):
"""Test billing decorator uses local config without loading tenant features."""
@cloud_edition_billing_enabled
def billing_view():
return "billing_success"
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True):
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
result = billing_view()
assert result == "billing_success"
get_features.assert_not_called()
def test_should_reject_when_billing_config_disabled(self):
"""Test billing decorator rejects when local billing config is disabled."""
app = create_app_with_login()
@cloud_edition_billing_enabled
def billing_view():
return "billing_success"
with app.test_request_context():
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False):
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
with pytest.raises(Exception) as exc_info:
billing_view()
assert exc_info.value.code == 403
assert "Billing feature is not enabled" in str(exc_info.value.description)
get_features.assert_not_called()
class TestBillingResourceLimits:
"""Test billing resource limit decorators"""
@ -303,6 +341,53 @@ class TestRateLimiting:
mock_session.commit.assert_called_once()
class TestCloudUtmRecord:
"""Test cloud UTM recording decorator."""
def test_should_record_utm_when_billing_config_enabled_and_cookie_exists(self):
"""Test UTM recording uses billing config without loading tenant features."""
app = create_app_with_login()
@cloud_utm_record
def view():
return "success"
with app.test_request_context("/", headers={"Cookie": "utm_info={}"}):
with (
patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True),
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("u1"), "t1")),
patch("controllers.console.wraps.OperationService.record_utm") as record_utm,
patch("controllers.console.wraps.FeatureService.get_features") as get_features,
):
result = view()
assert result == "success"
record_utm.assert_called_once_with("t1", {})
get_features.assert_not_called()
def test_should_skip_utm_when_billing_config_disabled(self):
"""Test UTM recording skips tenant feature loading when billing config is disabled."""
app = create_app_with_login()
@cloud_utm_record
def view():
return "success"
with app.test_request_context("/", headers={"Cookie": "utm_info={}"}):
with (
patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False),
patch("controllers.console.wraps.current_account_with_tenant") as current_account,
patch("controllers.console.wraps.OperationService.record_utm") as record_utm,
patch("controllers.console.wraps.FeatureService.get_features") as get_features,
):
result = view()
assert result == "success"
current_account.assert_not_called()
record_utm.assert_not_called()
get_features.assert_not_called()
class TestSystemSetup:
"""Test system setup decorator"""

View File

@ -1396,12 +1396,10 @@ class TestIndexingRunnerEstimate:
"""Mock all external dependencies."""
with (
patch("core.indexing_runner.db") as mock_db,
patch("core.indexing_runner.FeatureService") as mock_feature_service,
patch("core.indexing_runner.IndexProcessorFactory") as mock_factory,
):
yield {
"db": mock_db,
"feature_service": mock_feature_service,
"factory": mock_factory,
}
@ -1411,13 +1409,9 @@ class TestIndexingRunnerEstimate:
runner = IndexingRunner()
tenant_id = str(uuid.uuid4())
# Mock feature service
mock_features = MagicMock()
mock_features.billing.enabled = True
mock_dependencies["feature_service"].get_features.return_value = mock_features
# Create too many extract settings
with patch("core.indexing_runner.dify_config") as mock_config:
mock_config.BILLING_ENABLED = True
mock_config.BATCH_UPLOAD_LIMIT = 10
extract_settings = [MagicMock() for _ in range(15)]