mirror of https://github.com/langgenius/dify.git
add deduct quota for llm node
This commit is contained in:
parent
4d7caa3458
commit
5fe0d50cee
|
|
@ -3,6 +3,7 @@ from typing import Optional, cast
|
|||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileObj
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
|
@ -21,6 +22,7 @@ from core.workflow.nodes.base_node import BaseNode
|
|||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from extensions.ext_database import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
|
|
@ -144,10 +146,15 @@ class LLMNode(BaseNode):
|
|||
)
|
||||
|
||||
# handle invoke result
|
||||
return self._handle_invoke_result(
|
||||
text, usage = self._handle_invoke_result(
|
||||
invoke_result=invoke_result
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
self._deduct_llm_quota(model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage
|
||||
|
||||
def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
|
||||
"""
|
||||
Handle invoke result
|
||||
|
|
@ -373,6 +380,53 @@ class LLMNode(BaseNode):
|
|||
|
||||
return prompt_messages, stop
|
||||
|
||||
def _deduct_llm_quota(self, model_instance: ModelInstance, usage: LLMUsage) -> None:
    """
    Charge the usage of one LLM invocation against the tenant's system-provider quota.

    Nothing is charged when the model is not served through a SYSTEM provider,
    when no quota configuration matches the currently active quota type, or
    when the matching configuration has a ``quota_limit`` of -1.

    :param model_instance: model instance that served the invocation
    :param usage: usage reported for the invocation
    :return:
    """
    configuration = model_instance.provider_model_bundle.configuration

    # Only system-hosted providers are metered here.
    if configuration.using_provider_type != ProviderType.SYSTEM:
        return

    system_config = configuration.system_configuration

    # First quota configuration whose type is the one currently in effect.
    active_quota = next(
        (
            candidate
            for candidate in system_config.quota_configurations
            if candidate.quota_type == system_config.current_quota_type
        ),
        None,
    )
    if active_quota is None:
        return

    unit = active_quota.quota_unit

    # A limit of -1 is never charged (presumably "unlimited" -- confirm against provider config).
    if active_quota.quota_limit == -1:
        return

    if not unit:
        return

    if unit == QuotaUnit.TOKENS:
        # Token-based quota: charge exactly what the call consumed.
        amount = usage.total_tokens
    elif unit == QuotaUnit.CREDITS:
        # Credit-based quota: models whose name contains 'gpt-4' cost 20 credits, others 1.
        amount = 20 if 'gpt-4' in model_instance.model else 1
    else:
        # Any other unit is charged one unit per invocation.
        amount = 1

    if amount is None:
        return

    # Atomic in-database increment; rows whose quota is already used up
    # (quota_limit <= quota_used) are left untouched by the filter.
    chargeable_rows = db.session.query(Provider).filter(
        Provider.tenant_id == self.tenant_id,
        Provider.provider_name == model_instance.provider,
        Provider.provider_type == ProviderType.SYSTEM.value,
        Provider.quota_type == system_config.current_quota_type.value,
        Provider.quota_limit > Provider.quota_used
    )
    chargeable_rows.update({'quota_used': Provider.quota_used + amount})
    db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue