refactor: consolidate LLM runtime model state on ModelInstance (#32746)

Signed-off-by: -LAN- <laipz8200@outlook.com>
-LAN- 2026-03-01 02:29:32 +08:00 committed by GitHub
parent 48d8667c4f
commit 962df17a15
20 changed files with 375 additions and 324 deletions
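
The refactor, in one sentence: instead of each LLM-family node re-fetching model configuration and threading model_schema/model_parameters through helper calls, the ModelInstance built for a node now carries the runtime invocation state (credentials, completion parameters, stop sequences) and downstream code reads it off the instance. A minimal sketch of the consolidated shape, using a simplified stand-in class and made-up values rather than the project's real constructor:

# Simplified illustration only; the real ModelInstance takes a provider_model_bundle
# and resolves credentials itself (see the core/model_manager diff below).
from collections.abc import Mapping, Sequence
from typing import Any


class ModelInstanceSketch:
    def __init__(self, provider: str, model_name: str) -> None:
        self.provider = provider
        self.model_name = model_name
        self.credentials: dict[str, Any] = {}
        # Runtime LLM invocation fields, defaulted empty and stamped later.
        self.parameters: Mapping[str, Any] = {}
        self.stop: Sequence[str] = ()


# A factory stamps the runtime state once...
instance = ModelInstanceSketch("openai", "gpt-4o-mini")
instance.parameters = {"temperature": 0.7, "max_tokens": 256}
instance.stop = ("Observation:",)

# ...and callers read it back instead of receiving separate model_parameters/stop arguments.
print(dict(instance.parameters), list(instance.stop))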

View File

@ -110,7 +110,6 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
@ -129,13 +128,9 @@ ignore_imports =
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.start.entities -> core.app.app_config.entities
core.workflow.nodes.start.start_node -> core.app.app_config.entities

View File

@ -83,14 +83,21 @@ def fetch_model_config(
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
@ -98,6 +105,6 @@ def fetch_model_config(
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=node_data_model.completion_params,
parameters=completion_params,
stop=stop,
)
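
The stop handling here changes from popping "stop" out of node_data_model.completion_params in place, which mutated the node's own data, to popping from a copy with a type guard for non-list values. A small sketch of the difference, with a hypothetical params mapping:

# Hypothetical completion params as they might arrive from node data.
completion_params_on_node = {"temperature": 0.2, "stop": ["\nHuman:"]}

# Old shape: pop() on the shared mapping removed "stop" from the node data itself.
# New shape: copy first, so the original mapping is left untouched.
completion_params = dict(completion_params_on_node)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
    stop = []

print(completion_params)          # {'temperature': 0.2}
print(stop)                       # ['\nHuman:']
print(completion_params_on_node)  # still contains 'stop'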

View File

@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, final
from typing import TYPE_CHECKING, Any, cast, final
from typing_extensions import override
@ -9,6 +9,9 @@ from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
@ -23,6 +26,8 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
@ -171,6 +176,7 @@ class DifyNodeFactory(NodeFactory):
)
if node_type == NodeType.LLM:
model_instance = self._build_model_instance_for_llm_node(node_data)
return LLMNode(
id=node_id,
config=node_config,
@ -178,6 +184,7 @@ class DifyNodeFactory(NodeFactory):
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
)
if node_type == NodeType.DATASOURCE:
@ -208,6 +215,7 @@ class DifyNodeFactory(NodeFactory):
)
if node_type == NodeType.QUESTION_CLASSIFIER:
model_instance = self._build_model_instance_for_llm_node(node_data)
return QuestionClassifierNode(
id=node_id,
config=node_config,
@ -215,9 +223,11 @@ class DifyNodeFactory(NodeFactory):
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
)
if node_type == NodeType.PARAMETER_EXTRACTOR:
model_instance = self._build_model_instance_for_llm_node(node_data)
return ParameterExtractorNode(
id=node_id,
config=node_config,
@ -225,6 +235,7 @@ class DifyNodeFactory(NodeFactory):
graph_runtime_state=self.graph_runtime_state,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
model_instance=model_instance,
)
return node_class(
@ -233,3 +244,37 @@ class DifyNodeFactory(NodeFactory):
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
node_data_model = ModelConfig.model_validate(node_data["model"])
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
completion_params = dict(node_data_model.completion_params)
stop = completion_params.pop("stop", [])
if not isinstance(stop, list):
stop = []
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.provider = node_data_model.provider
model_instance.model_name = node_data_model.name
model_instance.credentials = credentials
model_instance.parameters = completion_params
model_instance.stop = tuple(stop)
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance
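
The factory builds this runtime-populated instance once per LLM, question-classifier, or parameter-extractor node and injects it through the constructor, which is also what lets the tests further down swap in a MagicMock(spec=ModelInstance) instead of patching a private _fetch_model_config. A sketch of that injection style with a stand-in node class, not the real ones:

# Constructor-injection sketch (not the real node classes).
from unittest.mock import MagicMock


class NodeSketch:
    def __init__(self, *, model_instance) -> None:
        self._model_instance = model_instance

    def run(self) -> str:
        # At run time the node reads state straight off the injected instance.
        mi = self._model_instance
        return f"{mi.provider}/{mi.model_name} stop={list(mi.stop)}"


fake = MagicMock()
fake.provider = "openai"
fake.model_name = "gpt-3.5-turbo"
fake.stop = ()

print(NodeSketch(model_instance=fake).run())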

View File

@ -1,5 +1,5 @@
import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload
from configs import dify_config
@ -38,6 +38,9 @@ class ModelInstance:
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
# Runtime LLM invocation fields.
self.parameters: Mapping[str, Any] = {}
self.stop: Sequence[str] = ()
self.model_type_instance = self.provider_model_bundle.model_type_instance
self.load_balancing_manager = self._get_load_balancing_manager(
configuration=provider_model_bundle.configuration,
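
The two new attributes default to an empty mapping and an empty tuple, so dict(instance.parameters) and list(instance.stop) are valid at invoke sites even before a factory stamps real values onto the instance. A tiny illustration with hypothetical values:

from collections.abc import Mapping, Sequence
from typing import Any

parameters: Mapping[str, Any] = {}
stop: Sequence[str] = ()

print(dict(parameters), list(stop))  # {} []

# The node factory stores stop as a tuple; invoke sites convert it back to a list.
stop = ("Observation:",)
print(list(stop))  # ['Observation:']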

View File

@ -4,6 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@ -44,7 +45,8 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
prompt_messages = []
@ -59,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
@ -71,6 +74,7 @@ class AdvancedPromptTransform(PromptTransform):
memory_config=memory_config,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=image_detail_config,
)
@ -85,7 +89,8 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
@ -111,6 +116,7 @@ class AdvancedPromptTransform(PromptTransform):
parser=parser,
prompt_inputs=prompt_inputs,
model_config=model_config,
model_instance=model_instance,
)
if query:
@ -146,7 +152,8 @@ class AdvancedPromptTransform(PromptTransform):
context: str | None,
memory_config: MemoryConfig | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
"""
@ -198,8 +205,13 @@ class AdvancedPromptTransform(PromptTransform):
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
if memory and memory_config:
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
prompt_messages = self._append_chat_histories(
memory,
memory_config,
prompt_messages,
model_config=model_config,
model_instance=model_instance,
)
if files and query is not None:
for file in files:
prompt_message_contents.append(
@ -276,7 +288,8 @@ class AdvancedPromptTransform(PromptTransform):
role_prefix: MemoryConfig.RolePrefix,
parser: PromptTemplateParser,
prompt_inputs: Mapping[str, str],
model_config: ModelConfigWithCredentialsEntity,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> Mapping[str, str]:
prompt_inputs = dict(prompt_inputs)
if "#histories#" in parser.variable_keys:
@ -286,7 +299,11 @@ class AdvancedPromptTransform(PromptTransform):
prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
rest_tokens = self._calculate_rest_token(
[tmp_human_message],
model_config=model_config,
model_instance=model_instance,
)
histories = self._get_history_messages_from_memory(
memory=memory,

View File

@ -41,7 +41,7 @@ class AgentHistoryPromptTransform(PromptTransform):
if not self.memory:
return prompt_messages
max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config)
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

View File

@ -4,45 +4,83 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _resolve_model_runtime(
self,
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> tuple[ModelInstance, AIModelEntity]:
if model_instance is None:
if model_config is None:
raise ValueError("Either model_config or model_instance must be provided.")
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
model_instance.credentials = model_config.credentials
model_instance.parameters = model_config.parameters
model_instance.stop = model_config.stop
model_schema = model_instance.model_type_instance.get_model_schema(
model=model_instance.model_name,
credentials=model_instance.credentials,
)
if model_schema is None:
if model_config is None:
raise ValueError("Model schema not found for the provided model instance.")
model_schema = model_config.model_schema
return model_instance, model_schema
def _append_chat_histories(
self,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> list[PromptMessage]:
rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
rest_tokens = self._calculate_rest_token(
prompt_messages,
model_config=model_config,
model_instance=model_instance,
)
histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
prompt_messages.extend(histories)
return prompt_messages
def _calculate_rest_token(
self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
self,
prompt_messages: list[PromptMessage],
*,
model_config: ModelConfigWithCredentialsEntity | None = None,
model_instance: ModelInstance | None = None,
) -> int:
model_instance, model_schema = self._resolve_model_runtime(
model_config=model_config,
model_instance=model_instance,
)
model_parameters = model_instance.parameters
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
model_parameters.get(parameter_rule.name)
or model_parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
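
_resolve_model_runtime prefers an explicitly passed model_instance and only builds one from model_config as a fallback, raising when neither is available or when no schema can be resolved. Once resolved, the remaining-token budget is the same arithmetic as before: context size minus the configured max_tokens minus the tokens already in the prompt, with a fixed default when the schema exposes no context size. A standalone sketch of that budget with made-up numbers:

def estimate_rest_tokens(
    context_size: int | None,
    max_tokens_param: int,
    current_message_tokens: int,
    default: int = 2000,
) -> int:
    # Mirrors the calculation above: fall back to a fixed default when the model
    # schema does not expose a context size.
    if not context_size:
        return default
    return context_size - max_tokens_param - current_message_tokens


# Hypothetical numbers: a 16k-context model, max_tokens=1024, ~900 prompt tokens.
print(estimate_rest_tokens(16384, 1024, 900))  # 14460
print(estimate_rest_tokens(None, 1024, 900))   # 2000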

View File

@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform):
if memory:
tmp_human_message = UserPromptMessage(content=prompt)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config)
histories = self._get_history_messages_from_memory(
memory=memory,
memory_config=MemoryConfig(

View File

@ -5,20 +5,16 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@ -29,46 +25,14 @@ from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError
def fetch_model_config(
*,
node_data_model: ModelConfig,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
if not node_data_model.mode:
raise LLMModeRequiredError("LLM mode is required.")
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
provider_model_bundle = model_instance.provider_model_bundle
provider_model = provider_model_bundle.configuration.get_provider_model(
model=node_data_model.name,
model_type=ModelType.LLM,
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
model_instance.credentials,
)
if provider_model is None:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
provider_model.raise_for_status()
stop: list[str] = []
if "stop" in node_data_model.completion_params:
stop = node_data_model.completion_params.pop("stop")
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
return model_instance, ModelConfigWithCredentialsEntity(
provider=node_data_model.provider,
model=node_data_model.name,
model_schema=model_schema,
mode=node_data_model.mode,
provider_model_bundle=provider_model_bundle,
credentials=credentials,
parameters=node_data_model.completion_params,
stop=stop,
)
raise ValueError(f"Model schema not found for {model_instance.model_name}")
return model_schema
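
fetch_model_schema replaces the removed fetch_model_config in this module: it only resolves the schema for an already-populated instance and raises ValueError when none is found, leaving callers to translate that into their own error types (the parameter-extractor node below wraps it in ModelSchemaNotFoundError). A sketch of that wrapping pattern using stand-in types, not the project's imports:

class ModelSchemaNotFoundError(Exception):
    """Stand-in for the node-level error type."""


def fetch_model_schema_sketch(model_name: str) -> dict:
    # Stand-in for llm_utils.fetch_model_schema: raise when the schema is missing.
    known = {"gpt-3.5-turbo": {"context_size": 16385}}
    schema = known.get(model_name)
    if schema is None:
        raise ValueError(f"Model schema not found for {model_name}")
    return schema


def load_schema_or_node_error(model_name: str) -> dict:
    try:
        return fetch_model_schema_sketch(model_name)
    except ValueError as exc:
        raise ModelSchemaNotFoundError("Model schema not found") from exc


print(load_schema_or_node_error("gpt-3.5-turbo"))
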
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:

View File

@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@ -38,7 +37,7 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -83,7 +82,6 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
)
from .exc import (
InvalidContextStructureError,
@ -116,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
_llm_file_saver: LLMFileSaver
_credentials_provider: CredentialsProvider
_model_factory: ModelFactory
_model_instance: ModelInstance
def __init__(
self,
@ -126,6 +125,7 @@ class LLMNode(Node[LLMNodeData]):
*,
credentials_provider: CredentialsProvider,
model_factory: ModelFactory,
model_instance: ModelInstance,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@ -139,6 +139,7 @@ class LLMNode(Node[LLMNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -202,21 +203,10 @@ class LLMNode(Node[LLMNodeData]):
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
# fetch model config
model_instance, model_config = self._fetch_model_config(
node_data_model=self.node_data.model,
)
model_name = getattr(model_instance, "model_name", None)
if not isinstance(model_name, str):
model_name = model_config.model
model_provider = getattr(model_instance, "provider", None)
if not isinstance(model_provider, str):
model_provider = model_config.provider
model_schema = model_instance.model_type_instance.get_model_schema(
model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_name}")
model_instance = self._model_instance
model_name = model_instance.model_name
model_provider = model_instance.provider
model_stop = model_instance.stop
# fetch memory
memory = llm_utils.fetch_memory(
@ -240,9 +230,7 @@ class LLMNode(Node[LLMNodeData]):
context=context,
memory=memory,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=self.node_data.model.completion_params,
stop=model_config.stop,
stop=model_stop,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
@ -254,7 +242,6 @@ class LLMNode(Node[LLMNodeData]):
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@ -371,7 +358,6 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def invoke_llm(
*,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None = None,
@ -384,11 +370,10 @@ class LLMNode(Node[LLMNodeData]):
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
)
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
model_parameters = model_instance.parameters
invoke_model_parameters = dict(model_parameters)
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
output_schema = LLMNode.fetch_structured_output_schema(
@ -402,7 +387,7 @@ class LLMNode(Node[LLMNodeData]):
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=output_schema,
model_parameters=node_data_model.completion_params,
model_parameters=invoke_model_parameters,
stop=list(stop or []),
stream=True,
user=user_id,
@ -412,7 +397,7 @@ class LLMNode(Node[LLMNodeData]):
invoke_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages),
model_parameters=node_data_model.completion_params,
model_parameters=invoke_model_parameters,
stop=list(stop or []),
stream=True,
user=user_id,
@ -771,23 +756,6 @@ class LLMNode(Node[LLMNodeData]):
return None
def _fetch_model_config(
self,
*,
node_data_model: ModelConfig,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
completion_params = model_config_with_cred.parameters
model_config_with_cred.parameters = completion_params
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
node_data_model.completion_params = completion_params
return model, model_config_with_cred
@staticmethod
def fetch_prompt_messages(
*,
@ -796,8 +764,6 @@ class LLMNode(Node[LLMNodeData]):
context: str | None = None,
memory: TokenBufferMemory | None = None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
@ -808,6 +774,7 @@ class LLMNode(Node[LLMNodeData]):
context_files: list[File] | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
@ -826,8 +793,6 @@ class LLMNode(Node[LLMNodeData]):
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
@ -865,8 +830,6 @@ class LLMNode(Node[LLMNodeData]):
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
@ -1316,23 +1279,23 @@ def _calculate_rest_token(
*,
prompt_messages: list[PromptMessage],
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> int:
rest_tokens = 2000
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_schema.parameter_rules:
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_parameters.get(parameter_rule.name)
or model_parameters.get(str(parameter_rule.use_template))
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
@ -1347,8 +1310,6 @@ def _handle_memory_chat_mode(
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
@ -1356,8 +1317,6 @@ def _handle_memory_chat_mode(
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
@ -1371,8 +1330,6 @@ def _handle_memory_completion_mode(
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
model_schema: AIModelEntity,
model_parameters: Mapping[str, Any],
) -> str:
memory_text = ""
# Get history text from memory for completion model
@ -1380,8 +1337,6 @@ def _handle_memory_completion_mode(
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
model_schema=model_schema,
model_parameters=model_parameters,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
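
invoke_llm now takes its completion parameters from the instance instead of node_data_model.completion_params and copies them into a local dict before invoking, so any keys added downstream stay off the shared instance. A small sketch of that defensive copy, with hypothetical parameter names:

from collections.abc import Mapping
from typing import Any

# Parameters as stamped on the model instance (hypothetical values).
instance_parameters: Mapping[str, Any] = {"temperature": 0.1, "max_tokens": 512}

# Copy before invoking; later code can add keys without touching the instance.
invoke_model_parameters = dict(instance_parameters)
invoke_model_parameters["response_format"] = {"type": "json_schema"}

print(sorted(invoke_model_parameters))  # ['max_tokens', 'response_format', 'temperature']
print(sorted(instance_parameters))      # ['max_tokens', 'temperature'] -- unchanged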

View File

@ -5,7 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
@ -31,7 +30,7 @@ from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.nodes.llm import llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type
@ -95,8 +94,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: ModelInstance | None = None
_model_config: ModelConfigWithCredentialsEntity | None = None
_model_instance: ModelInstance
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
@ -109,6 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
) -> None:
super().__init__(
id=id,
@ -118,6 +117,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@ -155,18 +155,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else []
)
model_instance, model_config = self._fetch_model_config(node_data.model)
model_instance = self._model_instance
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(
model=model_config.model,
credentials=model_config.credentials,
)
if not model_schema:
raise ModelSchemaNotFoundError("Model schema not found")
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@ -184,7 +180,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_config=model_config,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=node_data.vision.configs.detail,
@ -195,7 +191,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
data=node_data,
query=query,
variable_pool=self.graph_runtime_state.variable_pool,
model_config=model_config,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=node_data.vision.configs.detail,
@ -211,24 +207,23 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
}
process_data = {
"model_mode": model_config.mode,
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": None,
"function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
"tool_call": None,
"model_provider": model_config.provider,
"model_name": model_config.model,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
}
try:
text, usage, tool_call = self._invoke(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
tools=prompt_message_tools,
stop=model_config.stop,
stop=model_instance.stop,
)
process_data["usage"] = jsonable_encoder(usage)
process_data["tool_call"] = jsonable_encoder(tool_call)
@ -290,17 +285,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
def _invoke(
self,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: list[PromptMessage],
tools: list[PromptMessageTool],
stop: list[str],
stop: Sequence[str],
) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
invoke_result = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=node_data_model.completion_params,
model_parameters=dict(model_instance.parameters),
tools=tools,
stop=stop,
stop=list(stop),
stream=False,
user=self.user_id,
)
@ -324,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@ -337,7 +331,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
prompt_template = self._get_function_calling_prompt_template(
node_data, query, variable_pool, memory, rest_token
)
@ -349,7 +349,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context="",
memory_config=node_data.memory,
memory=None,
model_config=model_config,
model_instance=model_instance,
image_detail_config=vision_detail,
)
@ -406,7 +406,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@ -421,7 +421,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data=data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=vision_detail,
@ -431,7 +431,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data=data,
query=query,
variable_pool=variable_pool,
model_config=model_config,
model_instance=model_instance,
memory=memory,
files=files,
vision_detail=vision_detail,
@ -444,7 +444,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@ -454,7 +454,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
@ -467,7 +471,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context="",
memory_config=node_data.memory,
memory=memory,
model_config=model_config,
model_instance=model_instance,
image_detail_config=vision_detail,
)
@ -478,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@ -488,7 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"""
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
rest_token = self._calculate_rest_token(
node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
node_data=node_data,
query=query,
variable_pool=variable_pool,
model_instance=model_instance,
context="",
)
prompt_template = self._get_prompt_engineering_prompt_template(
node_data=node_data,
@ -508,7 +516,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context="",
memory_config=node_data.memory,
memory=None,
model_config=model_config,
model_instance=model_instance,
image_detail_config=vision_detail,
)
@ -769,21 +777,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
context: str | None,
) -> int:
try:
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
except ValueError as exc:
raise ModelSchemaNotFoundError("Model schema not found") from exc
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
model_instance, model_config = self._fetch_model_config(node_data.model)
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
raise InvalidModelTypeError("Model is not a Large Language Model")
llm_model = model_instance.model_type_instance
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
if not model_schema:
raise ModelSchemaNotFoundError("Model schema not found")
if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
else:
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
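
Worth noting in the hunk above: the old feature check intersected with {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} (the same member twice), so models that expose only plain TOOL_CALL always fell through to the prompt-engineering branch; the corrected set {TOOL_CALL, MULTI_TOOL_CALL} matches both. A toy illustration with a stand-in enum:

from enum import Enum


class ModelFeature(Enum):  # stand-in for the runtime enum
    TOOL_CALL = "tool-call"
    MULTI_TOOL_CALL = "multi-tool-call"


features = {ModelFeature.TOOL_CALL}  # a model with single tool-call support only

old_check = features & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
new_check = features & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}

print(bool(old_check))  # False -- duplicate member, plain TOOL_CALL never matched
print(bool(new_check))  # True
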
@ -796,27 +799,28 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config,
model_instance=model_instance,
)
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
curr_message_tokens = (
model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
model_type_instance.get_num_tokens(
model_instance.model_name, model_instance.credentials, prompt_messages
)
+ 1000
) # add 1000 to ensure tool call messages
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@ -824,21 +828,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
def _fetch_model_config(
self, node_data_model: ModelConfig
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config.
"""
if not self._model_instance or not self._model_config:
self._model_instance, self._model_config = llm_utils.fetch_model_config(
node_data_model=node_data_model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@ -3,12 +3,10 @@ import re
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities import GraphInitParams
@ -22,7 +20,12 @@ from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm import (
LLMNode,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -52,6 +55,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_llm_file_saver: LLMFileSaver
_credentials_provider: "CredentialsProvider"
_model_factory: "ModelFactory"
_model_instance: ModelInstance
def __init__(
self,
@ -62,6 +66,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
*,
credentials_provider: "CredentialsProvider",
model_factory: "ModelFactory",
model_instance: ModelInstance,
llm_file_saver: LLMFileSaver | None = None,
):
super().__init__(
@ -75,6 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._credentials_provider = credentials_provider
self._model_factory = model_factory
self._model_instance = model_instance
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@ -95,18 +101,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
model_instance, model_config = llm_utils.fetch_model_config(
node_data_model=node_data.model,
credentials_provider=self._credentials_provider,
model_factory=self._model_factory,
)
model_schema = model_instance.model_type_instance.get_model_schema(
model_instance.model_name,
model_instance.credentials,
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
# fetch model instance
model_instance = self._model_instance
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@ -131,7 +127,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
rest_token = self._calculate_rest_token(
node_data=node_data,
query=query or "",
model_config=model_config,
model_instance=model_instance,
context="",
)
prompt_template = self._get_prompt_template(
@ -149,9 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
sys_query="",
memory=memory,
model_instance=model_instance,
model_schema=model_schema,
model_parameters=node_data.model.completion_params,
stop=model_config.stop,
stop=model_instance.stop,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
@ -166,7 +160,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@ -205,14 +198,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
category_name = classes_map[category_id_result]
category_id = category_id_result
process_data = {
"model_mode": model_config.mode,
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
}
outputs = {
"class_name": category_name,
@ -285,39 +278,40 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
sys_query="",
sys_files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config,
model_instance=model_instance,
stop=model_instance.stop,
memory_config=node_data.memory,
vision_enabled=False,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
rest_tokens = 2000
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
) or 0
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens

View File

@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
)
return MagicMock(return_value=(model_instance, model_config))
def get_mocked_fetch_model_instance(
provider: str,
model: str,
mode: str,
credentials: dict,
):
mock_fetch_model_config = get_mocked_fetch_model_config(
provider=provider,
model=model,
mode=mode,
credentials=credentials,
)
model_instance, _ = mock_fetch_model_config()
return MagicMock(return_value=model_instance)
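
get_mocked_fetch_model_instance keeps the shape of the older helper: a MagicMock whose return value is the prepared instance. That is why the updated tests below assign node._model_instance = get_mocked_fetch_model_instance(...)() with a trailing call. A minimal illustration of the pattern:

from unittest.mock import MagicMock

prepared_instance = MagicMock(name="model_instance")
helper = MagicMock(return_value=prepared_instance)

# Calling the helper mock yields the prepared instance, so the tests call it once
# and assign the result to the node's private attribute.
assert helper() is prepared_instance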

View File

@ -5,13 +5,13 @@ from collections.abc import Generator
from unittest.mock import MagicMock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = LLMNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(),
model_factory=MagicMock(),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node
@ -116,8 +109,7 @@ def test_execute_llm():
db.session.close = MagicMock()
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
@ -125,7 +117,20 @@ def test_execute_llm():
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -149,14 +154,7 @@ def test_execute_llm():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
@ -167,10 +165,9 @@ def test_execute_llm():
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
# execute node
result = node._run()
assert isinstance(result, Generator)
@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
# Mock db.session.close()
db.session.close = MagicMock()
# Mock the _fetch_model_config method
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock
@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
node._model_instance = build_mock_model_instance()
with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
# execute node
result = node._run()

View File

@ -4,18 +4,17 @@ import uuid
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
config=config,
@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node
@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
}
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo-instruct",
mode="completion",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()
result = node._run()
@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
},
)
node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()

View File

@ -1391,10 +1391,20 @@ class TestWorkflowService:
workflow_service = WorkflowService()
from unittest.mock import patch
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
# Act
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
):
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
)
# Assert
assert result is not None

View File

@ -10,6 +10,7 @@ from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -44,9 +45,10 @@ class MockNodeMixin:
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
if isinstance(self, (LLMNode, QuestionClassifierNode)):
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
super().__init__(
id=id,

View File

@ -9,11 +9,12 @@ This test validates that:
"""
import time
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@ -115,7 +116,12 @@ def test_parallel_streaming_workflow():
# Create node factory and graph
node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
# Create the graph engine
engine = GraphEngine(

View File

@ -547,8 +547,22 @@ class TableTestRunner:
"""Run tests in parallel."""
results = []
flask_app: Any = None
try:
from flask import current_app
flask_app = current_app._get_current_object() # type: ignore[attr-defined]
except RuntimeError:
flask_app = None
def _run_test_case_with_context(test_case: WorkflowTestCase) -> WorkflowTestResult:
if flask_app is None:
return self.run_test_case(test_case)
with flask_app.app_context():
return self.run_test_case(test_case)
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases}
future_to_test = {executor.submit(_run_test_case_with_context, tc): tc for tc in test_cases}
for future in as_completed(future_to_test):
test_case = future_to_test[future]
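
Flask's application context does not carry over to threads spawned by the ThreadPoolExecutor, so parallel test cases would otherwise run without the context the submitting thread had; the runner now captures the app via current_app._get_current_object() and re-enters its context inside each worker. A standalone sketch of the same pattern, assuming Flask is installed and holding the app object directly:

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app

app = Flask(__name__)


def task() -> str:
    # Resolves only because the worker thread pushed an app context itself;
    # the context from the submitting thread does not carry over.
    return current_app.name


def task_with_context(flask_app: Flask) -> str:
    with flask_app.app_context():
        return task()


with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(task_with_context, app) for _ in range(2)]
    results = [f.result() for f in futures]

print(results)  # e.g. ['__main__', '__main__'] when run as a script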

View File

@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
@ -115,6 +116,7 @@ def llm_node(
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
)
return node
@ -601,6 +603,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver