diff --git a/api/core/workflow/nodes/llm/llm_node.py b/api/core/workflow/nodes/llm/llm_node.py
index d1050a5f5b..9285bbe74e 100644
--- a/api/core/workflow/nodes/llm/llm_node.py
+++ b/api/core/workflow/nodes/llm/llm_node.py
@@ -3,6 +3,7 @@ from typing import Optional, cast
 
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.model_entities import ModelStatus
+from core.entities.provider_entities import QuotaUnit
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.file.file_obj import FileObj
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -21,6 +22,7 @@ from core.workflow.nodes.base_node import BaseNode
 from core.workflow.nodes.llm.entities import LLMNodeData
 from extensions.ext_database import db
 from models.model import Conversation
+from models.provider import Provider, ProviderType
 from models.workflow import WorkflowNodeExecutionStatus
 
 
@@ -144,10 +146,15 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        return self._handle_invoke_result(
+        text, usage = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
+        # deduct quota
+        self._deduct_llm_quota(model_instance=model_instance, usage=usage)
+
+        return text, usage
+
     def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
         """
         Handle invoke result
@@ -373,6 +380,53 @@ class LLMNode(BaseNode):
 
         return prompt_messages, stop
 
+    def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None:
+        """
+        Deduct LLM quota
+        :param model_instance: model instance
+        :param usage: usage
+        :return:
+        """
+        provider_model_bundle = model_instance.provider_model_bundle
+        provider_configuration = provider_model_bundle.configuration
+
+        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+            return
+
+        system_configuration = provider_configuration.system_configuration
+
+        quota_unit = None
+        for quota_configuration in system_configuration.quota_configurations:
+            if quota_configuration.quota_type == system_configuration.current_quota_type:
+                quota_unit = quota_configuration.quota_unit
+
+                if quota_configuration.quota_limit == -1:
+                    return
+
+                break
+
+        used_quota = None
+        if quota_unit:
+            if quota_unit == QuotaUnit.TOKENS:
+                used_quota = usage.total_tokens
+            elif quota_unit == QuotaUnit.CREDITS:
+                used_quota = 1
+
+                if 'gpt-4' in model_instance.model:
+                    used_quota = 20
+            else:
+                used_quota = 1
+
+        if used_quota is not None:
+            db.session.query(Provider).filter(
+                Provider.tenant_id == self.tenant_id,
+                Provider.provider_name == model_instance.provider,
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == system_configuration.current_quota_type.value,
+                Provider.quota_limit > Provider.quota_used
+            ).update({'quota_used': Provider.quota_used + used_quota})
+            db.session.commit()
+
     @classmethod
     def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
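
Reviewer note: the deduction rule added in `_deduct_llm_quota` can be read as follows — token-based quotas deduct `usage.total_tokens`, credit-based quotas deduct 1 credit per call (20 when the model name contains `gpt-4`), and any other unit falls back to 1 per call. The sketch below restates just that branch in isolation; it is illustrative only, and the `QuotaUnit` enum and `Usage` dataclass here are simplified stand-ins for Dify's `core.entities.provider_entities.QuotaUnit` and `LLMUsage`, not the real classes.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class QuotaUnit(Enum):
    # Simplified stand-in for core.entities.provider_entities.QuotaUnit
    TIMES = 'times'
    TOKENS = 'tokens'
    CREDITS = 'credits'


@dataclass
class Usage:
    # Simplified stand-in for LLMUsage; only the field this rule needs
    total_tokens: int


def used_quota_for(quota_unit: Optional[QuotaUnit], model: str, usage: Usage) -> Optional[int]:
    """Mirror of the branch in _deduct_llm_quota that picks the amount to deduct."""
    if quota_unit is None:
        return None
    if quota_unit == QuotaUnit.TOKENS:
        return usage.total_tokens             # token quotas: deduct actual token usage
    if quota_unit == QuotaUnit.CREDITS:
        return 20 if 'gpt-4' in model else 1  # gpt-4 family costs 20 credits per call
    return 1                                  # other units (e.g. per-call quotas): 1 per call


assert used_quota_for(QuotaUnit.CREDITS, 'gpt-4-turbo', Usage(total_tokens=500)) == 20
assert used_quota_for(QuotaUnit.TOKENS, 'gpt-3.5-turbo', Usage(total_tokens=500)) == 500
```

The actual deduction then runs as a single UPDATE guarded by `Provider.quota_limit > Provider.quota_used`, so the increment is computed in the database and skipped once the quota is already exhausted.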