mirror of
https://github.com/langgenius/dify.git
synced 2026-06-07 16:32:01 +08:00
chore: use dify_config.BILLING_ENABLED (#36619)
This commit is contained in:
parent
3a467d1d63
commit
fbfb4b3a00
@ -82,9 +82,7 @@ def only_edition_self_hosted[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
def cloud_edition_billing_enabled[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if not features.billing.enabled:
|
||||
if not dify_config.BILLING_ENABLED:
|
||||
abort(403, "Billing feature is not enabled.")
|
||||
return view(*args, **kwargs)
|
||||
|
||||
@ -198,15 +196,11 @@ def cloud_utm_record[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
with contextlib.suppress(Exception):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
|
||||
if features.billing.enabled:
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
|
||||
if utm_info:
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
utm_info = request.cookies.get("utm_info")
|
||||
if dify_config.BILLING_ENABLED and utm_info:
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
utm_info_dict: UtmInfo = json.loads(utm_info)
|
||||
OperationService.record_utm(current_tenant_id, utm_info_dict)
|
||||
|
||||
return view(*args, **kwargs)
|
||||
|
||||
|
||||
@ -42,7 +42,6 @@ from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetPro
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus
|
||||
from models.model import UploadFile
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -282,8 +281,7 @@ class IndexingRunner:
|
||||
Estimate the indexing for the document.
|
||||
"""
|
||||
# check document limit
|
||||
features = FeatureService.get_features(tenant_id)
|
||||
if features.billing.enabled:
|
||||
if dify_config.BILLING_ENABLED:
|
||||
count = len(extract_settings)
|
||||
batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
|
||||
if count > batch_upload_limit:
|
||||
|
||||
@ -8,8 +8,10 @@ from controllers.console.error import NotInitValidateError, NotSetupError, Unaut
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_enabled,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
cloud_utm_record,
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
only_edition_enterprise,
|
||||
@ -147,6 +149,42 @@ class TestEditionChecks:
|
||||
assert result == "self_hosted_success"
|
||||
|
||||
|
||||
class TestBillingEnabled:
|
||||
"""Test billing enabled decorator."""
|
||||
|
||||
def test_should_allow_when_billing_config_enabled(self):
|
||||
"""Test billing decorator uses local config without loading tenant features."""
|
||||
|
||||
@cloud_edition_billing_enabled
|
||||
def billing_view():
|
||||
return "billing_success"
|
||||
|
||||
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
|
||||
result = billing_view()
|
||||
|
||||
assert result == "billing_success"
|
||||
get_features.assert_not_called()
|
||||
|
||||
def test_should_reject_when_billing_config_disabled(self):
|
||||
"""Test billing decorator rejects when local billing config is disabled."""
|
||||
app = create_app_with_login()
|
||||
|
||||
@cloud_edition_billing_enabled
|
||||
def billing_view():
|
||||
return "billing_success"
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features") as get_features:
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
billing_view()
|
||||
|
||||
assert exc_info.value.code == 403
|
||||
assert "Billing feature is not enabled" in str(exc_info.value.description)
|
||||
get_features.assert_not_called()
|
||||
|
||||
|
||||
class TestBillingResourceLimits:
|
||||
"""Test billing resource limit decorators"""
|
||||
|
||||
@ -303,6 +341,53 @@ class TestRateLimiting:
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestCloudUtmRecord:
|
||||
"""Test cloud UTM recording decorator."""
|
||||
|
||||
def test_should_record_utm_when_billing_config_enabled_and_cookie_exists(self):
|
||||
"""Test UTM recording uses billing config without loading tenant features."""
|
||||
app = create_app_with_login()
|
||||
|
||||
@cloud_utm_record
|
||||
def view():
|
||||
return "success"
|
||||
|
||||
with app.test_request_context("/", headers={"Cookie": "utm_info={}"}):
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.wraps.current_account_with_tenant", return_value=(MockUser("u1"), "t1")),
|
||||
patch("controllers.console.wraps.OperationService.record_utm") as record_utm,
|
||||
patch("controllers.console.wraps.FeatureService.get_features") as get_features,
|
||||
):
|
||||
result = view()
|
||||
|
||||
assert result == "success"
|
||||
record_utm.assert_called_once_with("t1", {})
|
||||
get_features.assert_not_called()
|
||||
|
||||
def test_should_skip_utm_when_billing_config_disabled(self):
|
||||
"""Test UTM recording skips tenant feature loading when billing config is disabled."""
|
||||
app = create_app_with_login()
|
||||
|
||||
@cloud_utm_record
|
||||
def view():
|
||||
return "success"
|
||||
|
||||
with app.test_request_context("/", headers={"Cookie": "utm_info={}"}):
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.wraps.current_account_with_tenant") as current_account,
|
||||
patch("controllers.console.wraps.OperationService.record_utm") as record_utm,
|
||||
patch("controllers.console.wraps.FeatureService.get_features") as get_features,
|
||||
):
|
||||
result = view()
|
||||
|
||||
assert result == "success"
|
||||
current_account.assert_not_called()
|
||||
record_utm.assert_not_called()
|
||||
get_features.assert_not_called()
|
||||
|
||||
|
||||
class TestSystemSetup:
|
||||
"""Test system setup decorator"""
|
||||
|
||||
|
||||
@ -1396,12 +1396,10 @@ class TestIndexingRunnerEstimate:
|
||||
"""Mock all external dependencies."""
|
||||
with (
|
||||
patch("core.indexing_runner.db") as mock_db,
|
||||
patch("core.indexing_runner.FeatureService") as mock_feature_service,
|
||||
patch("core.indexing_runner.IndexProcessorFactory") as mock_factory,
|
||||
):
|
||||
yield {
|
||||
"db": mock_db,
|
||||
"feature_service": mock_feature_service,
|
||||
"factory": mock_factory,
|
||||
}
|
||||
|
||||
@ -1411,13 +1409,9 @@ class TestIndexingRunnerEstimate:
|
||||
runner = IndexingRunner()
|
||||
tenant_id = str(uuid.uuid4())
|
||||
|
||||
# Mock feature service
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_dependencies["feature_service"].get_features.return_value = mock_features
|
||||
|
||||
# Create too many extract settings
|
||||
with patch("core.indexing_runner.dify_config") as mock_config:
|
||||
mock_config.BILLING_ENABLED = True
|
||||
mock_config.BATCH_UPLOAD_LIMIT = 10
|
||||
extract_settings = [MagicMock() for _ in range(15)]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user