diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py index 596e439a7a..fff6e8e77a 100644 --- a/api/core/workflow/nodes/llm/llm_node.py +++ b/api/core/workflow/nodes/llm/llm_node.py @@ -147,7 +147,7 @@ class LLMNode(BaseNode): ) # deduct quota - self._deduct_llm_quota(model_instance=model_instance, usage=usage) + self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) return text, usage @@ -418,9 +418,11 @@ class LLMNode(BaseNode): return prompt_messages, stop - def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None: + @classmethod + def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: """ Deduct LLM quota + :param tenant_id: tenant id :param model_instance: model instance :param usage: usage :return: @@ -457,7 +459,7 @@ class LLMNode(BaseNode): if used_quota is not None: db.session.query(Provider).filter( - Provider.tenant_id == self.tenant_id, + Provider.tenant_id == tenant_id, Provider.provider_name == model_instance.provider, Provider.provider_type == ProviderType.SYSTEM.value, Provider.quota_type == system_configuration.current_quota_type.value, diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 5d2a76d5e4..d351dfb692 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -4,7 +4,6 @@ from typing import Optional, Union, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus -from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -21,6 +20,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData from core.workflow.entities.node_entities import NodeRunResult, NodeType, SystemVariable from core.workflow.entities.variable_pool import VariablePool from core.workflow.nodes.base_node import BaseNode +from core.workflow.nodes.llm.llm_node import LLMNode from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData from core.workflow.nodes.question_classifier.template_prompts import ( QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, @@ -33,7 +33,6 @@ from core.workflow.nodes.question_classifier.template_prompts import ( ) from extensions.ext_database import db from models.model import Conversation -from models.provider import Provider, ProviderType from models.workflow import WorkflowNodeExecutionStatus @@ -338,7 +337,7 @@ class QuestionClassifierNode(BaseNode): ) # deduct quota - self._deduct_llm_quota(model_instance=model_instance, usage=usage) + LLMNode.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) return text, usage @@ -371,50 +370,3 @@ class QuestionClassifierNode(BaseNode): usage = LLMUsage.empty_usage() return full_text, usage - - def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None: - """ - Deduct LLM quota - :param model_instance: model instance - :param usage: usage - :return: - """ - provider_model_bundle = model_instance.provider_model_bundle - provider_configuration = provider_model_bundle.configuration - - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - - quota_unit = None - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.quota_type == system_configuration.current_quota_type: - quota_unit = quota_configuration.quota_unit - - if quota_configuration.quota_limit == -1: - return - - break - - used_quota = None - if quota_unit: - if quota_unit == QuotaUnit.TOKENS: - used_quota = usage.total_tokens - elif quota_unit == QuotaUnit.CREDITS: - used_quota = 1 - - if 'gpt-4' in model_instance.model: - used_quota = 20 - else: - used_quota = 1 - - if used_quota is not None: - db.session.query(Provider).filter( - Provider.tenant_id == self.tenant_id, - Provider.provider_name == model_instance.provider, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + used_quota}) - db.session.commit()