From 962df17a1552e663069d58188947a09e28e0720d Mon Sep 17 00:00:00 2001 From: -LAN- Date: Sun, 1 Mar 2026 02:29:32 +0800 Subject: [PATCH] refactor: consolidate LLM runtime model state on ModelInstance (#32746) Signed-off-by: -LAN- --- api/.importlinter | 5 - api/core/app/llm/model_access.py | 15 +- api/core/app/workflow/node_factory.py | 47 ++++++- api/core/model_manager.py | 5 +- api/core/prompt/advanced_prompt_transform.py | 31 +++- .../prompt/agent_history_prompt_transform.py | 2 +- api/core/prompt/prompt_transform.py | 62 ++++++-- api/core/prompt/simple_prompt_transform.py | 2 +- api/core/workflow/nodes/llm/llm_utils.py | 50 +------ api/core/workflow/nodes/llm/node.py | 89 +++--------- .../parameter_extractor_node.py | 133 ++++++++---------- .../question_classifier_node.py | 74 +++++----- .../workflow/nodes/__mock/model.py | 16 +++ .../workflow/nodes/test_llm.py | 85 +++++------ .../nodes/test_parameter_extractor.py | 34 ++--- .../services/test_workflow_service.py | 16 ++- .../workflow/graph_engine/test_mock_nodes.py | 4 +- .../test_parallel_streaming_workflow.py | 10 +- .../graph_engine/test_table_runner.py | 16 ++- .../core/workflow/nodes/llm/test_node.py | 3 + 20 files changed, 375 insertions(+), 324 deletions(-) diff --git a/api/.importlinter b/api/.importlinter index 725999c28e..c180f8d76b 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -110,7 +110,6 @@ ignore_imports = core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory core.workflow.nodes.llm.llm_utils -> configs - core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities core.workflow.nodes.llm.llm_utils -> core.model_manager core.workflow.nodes.llm.protocols -> core.model_manager core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model @@ -129,13 +128,9 @@ ignore_imports = core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities - core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model - core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities - core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform core.workflow.nodes.start.entities -> core.app.app_config.entities core.workflow.nodes.start.start_node -> core.app.app_config.entities diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index 2b162920ee..ebae830389 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -83,14 +83,21 @@ def fetch_model_config( raise ModelNotExistError(f"Model {node_data_model.name} not exist.") 
provider_model.raise_for_status() - stop: list[str] = [] - if "stop" in node_data_model.completion_params: - stop = node_data_model.completion_params.pop("stop") + completion_params = dict(node_data_model.completion_params) + stop = completion_params.pop("stop", []) + if not isinstance(stop, list): + stop = [] model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) if not model_schema: raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + model_instance.provider = node_data_model.provider + model_instance.model_name = node_data_model.name + model_instance.credentials = credentials + model_instance.parameters = completion_params + model_instance.stop = tuple(stop) + return model_instance, ModelConfigWithCredentialsEntity( provider=node_data_model.provider, model=node_data_model.name, @@ -98,6 +105,6 @@ def fetch_model_config( mode=node_data_model.mode, provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=node_data_model.completion_params, + parameters=completion_params, stop=stop, ) diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index 3eeb1d5d58..159500a609 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, final +from typing import TYPE_CHECKING, Any, cast, final from typing_extensions import override @@ -9,6 +9,9 @@ from core.datasource.datasource_manager import DatasourceManager from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.ssrf_proxy import ssrf_proxy +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.graph_config import NodeConfigDict @@ -23,6 +26,8 @@ from core.workflow.nodes.datasource import DatasourceNode from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode @@ -171,6 +176,7 @@ class DifyNodeFactory(NodeFactory): ) if node_type == NodeType.LLM: + model_instance = self._build_model_instance_for_llm_node(node_data) return LLMNode( id=node_id, config=node_config, @@ -178,6 +184,7 @@ class DifyNodeFactory(NodeFactory): graph_runtime_state=self.graph_runtime_state, credentials_provider=self._llm_credentials_provider, model_factory=self._llm_model_factory, + model_instance=model_instance, ) if node_type == NodeType.DATASOURCE: @@ -208,6 +215,7 @@ class DifyNodeFactory(NodeFactory): ) if node_type == NodeType.QUESTION_CLASSIFIER: + model_instance = self._build_model_instance_for_llm_node(node_data) 
return QuestionClassifierNode( id=node_id, config=node_config, @@ -215,9 +223,11 @@ class DifyNodeFactory(NodeFactory): graph_runtime_state=self.graph_runtime_state, credentials_provider=self._llm_credentials_provider, model_factory=self._llm_model_factory, + model_instance=model_instance, ) if node_type == NodeType.PARAMETER_EXTRACTOR: + model_instance = self._build_model_instance_for_llm_node(node_data) return ParameterExtractorNode( id=node_id, config=node_config, @@ -225,6 +235,7 @@ class DifyNodeFactory(NodeFactory): graph_runtime_state=self.graph_runtime_state, credentials_provider=self._llm_credentials_provider, model_factory=self._llm_model_factory, + model_instance=model_instance, ) return node_class( @@ -233,3 +244,37 @@ class DifyNodeFactory(NodeFactory): graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) + + def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance: + node_data_model = ModelConfig.model_validate(node_data["model"]) + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") + + credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) + model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) + provider_model_bundle = model_instance.provider_model_bundle + + provider_model = provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, + model_type=ModelType.LLM, + ) + if provider_model is None: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() + + completion_params = dict(node_data_model.completion_params) + stop = completion_params.pop("stop", []) + if not isinstance(stop, list): + stop = [] + + model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) + if not model_schema: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + model_instance.provider = node_data_model.provider + model_instance.model_name = node_data_model.name + model_instance.credentials = credentials + model_instance.parameters = completion_params + model_instance.stop = tuple(stop) + model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) + return model_instance diff --git a/api/core/model_manager.py b/api/core/model_manager.py index ac096c5e54..2b3a3be1b9 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable, Generator, Iterable, Sequence +from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload from configs import dify_config @@ -38,6 +38,9 @@ class ModelInstance: self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) + # Runtime LLM invocation fields. 
+ self.parameters: Mapping[str, Any] = {} + self.stop: Sequence[str] = () self.model_type_instance = self.provider_model_bundle.model_type_instance self.load_balancing_manager = self._get_load_balancing_manager( configuration=provider_model_bundle.configuration, diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index fd1b7d838c..771b6be332 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -4,6 +4,7 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance from core.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, @@ -44,7 +45,8 @@ class AdvancedPromptTransform(PromptTransform): context: str | None, memory_config: MemoryConfig | None, memory: TokenBufferMemory | None, - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: prompt_messages = [] @@ -59,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, memory=memory, model_config=model_config, + model_instance=model_instance, image_detail_config=image_detail_config, ) elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template): @@ -71,6 +74,7 @@ class AdvancedPromptTransform(PromptTransform): memory_config=memory_config, memory=memory, model_config=model_config, + model_instance=model_instance, image_detail_config=image_detail_config, ) @@ -85,7 +89,8 @@ class AdvancedPromptTransform(PromptTransform): context: str | None, memory_config: MemoryConfig | None, memory: TokenBufferMemory | None, - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ @@ -111,6 +116,7 @@ class AdvancedPromptTransform(PromptTransform): parser=parser, prompt_inputs=prompt_inputs, model_config=model_config, + model_instance=model_instance, ) if query: @@ -146,7 +152,8 @@ class AdvancedPromptTransform(PromptTransform): context: str | None, memory_config: MemoryConfig | None, memory: TokenBufferMemory | None, - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> list[PromptMessage]: """ @@ -198,8 +205,13 @@ class AdvancedPromptTransform(PromptTransform): prompt_message_contents: list[PromptMessageContentUnionTypes] = [] if memory and memory_config: - prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config) - + prompt_messages = self._append_chat_histories( + memory, + memory_config, + prompt_messages, + model_config=model_config, + model_instance=model_instance, + ) if files and query is not None: for file in files: prompt_message_contents.append( @@ -276,7 +288,8 @@ class AdvancedPromptTransform(PromptTransform): role_prefix: MemoryConfig.RolePrefix, parser: PromptTemplateParser, prompt_inputs: 
Mapping[str, str], - model_config: ModelConfigWithCredentialsEntity, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, ) -> Mapping[str, str]: prompt_inputs = dict(prompt_inputs) if "#histories#" in parser.variable_keys: @@ -286,7 +299,11 @@ class AdvancedPromptTransform(PromptTransform): prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs} tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs)) - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + rest_tokens = self._calculate_rest_token( + [tmp_human_message], + model_config=model_config, + model_instance=model_instance, + ) histories = self._get_history_messages_from_memory( memory=memory, diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index 2b32062140..c1ae47709f 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -41,7 +41,7 @@ class AgentHistoryPromptTransform(PromptTransform): if not self.memory: return prompt_messages - max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config) + max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config) model_type_instance = self.model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index a6e873d587..22ef5809bb 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -4,45 +4,83 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey from core.prompt.entities.advanced_prompt_entities import MemoryConfig class PromptTransform: + def _resolve_model_runtime( + self, + *, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, + ) -> tuple[ModelInstance, AIModelEntity]: + if model_instance is None: + if model_config is None: + raise ValueError("Either model_config or model_instance must be provided.") + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + model_instance.credentials = model_config.credentials + model_instance.parameters = model_config.parameters + model_instance.stop = model_config.stop + + model_schema = model_instance.model_type_instance.get_model_schema( + model=model_instance.model_name, + credentials=model_instance.credentials, + ) + if model_schema is None: + if model_config is None: + raise ValueError("Model schema not found for the provided model instance.") + model_schema = model_config.model_schema + + return model_instance, model_schema + def _append_chat_histories( self, memory: TokenBufferMemory, memory_config: MemoryConfig, prompt_messages: list[PromptMessage], - model_config: ModelConfigWithCredentialsEntity, + *, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, ) -> list[PromptMessage]: - rest_tokens = 
self._calculate_rest_token(prompt_messages, model_config) + rest_tokens = self._calculate_rest_token( + prompt_messages, + model_config=model_config, + model_instance=model_instance, + ) histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens) prompt_messages.extend(histories) return prompt_messages def _calculate_rest_token( - self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + self, + prompt_messages: list[PromptMessage], + *, + model_config: ModelConfigWithCredentialsEntity | None = None, + model_instance: ModelInstance | None = None, ) -> int: + model_instance, model_schema = self._resolve_model_runtime( + model_config=model_config, + model_instance=model_instance, + ) + model_parameters = model_instance.parameters rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") + model_parameters.get(parameter_rule.name) + or model_parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index d6abbaaa69..936a093488 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform): if memory: tmp_human_message = UserPromptMessage(content=prompt) - rest_tokens = self._calculate_rest_token([tmp_human_message], model_config) + rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config) histories = self._get_history_messages_from_memory( memory=memory, memory_config=MemoryConfig( diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 341a1c1a4c..cf509f65f0 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -5,20 +5,16 @@ from sqlalchemy import select, update from sqlalchemy.orm import Session from configs import dify_config -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment from 
core.workflow.enums import SystemVariableKey from core.workflow.file.models import File -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.runtime import VariablePool from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -29,46 +25,14 @@ from models.provider_ids import ModelProviderID from .exc import InvalidVariableTypeError -def fetch_model_config( - *, - node_data_model: ModelConfig, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, -) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, +def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( + model_instance.model_name, + model_instance.credentials, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - stop: list[str] = [] - if "stop" in node_data_model.completion_params: - stop = node_data_model.completion_params.pop("stop") - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - return model_instance, ModelConfigWithCredentialsEntity( - provider=node_data_model.provider, - model=node_data_model.name, - model_schema=model_schema, - mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, - credentials=credentials, - parameters=node_data_model.completion_params, - stop=stop, - ) + raise ValueError(f"Model schema not found for {model_instance.model_name}") + return model_schema def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]: diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 0259434d90..ec23fd7231 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import select -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output @@ -38,7 +37,7 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder from 
core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -83,7 +82,6 @@ from .entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, LLMNodeData, - ModelConfig, ) from .exc import ( InvalidContextStructureError, @@ -116,6 +114,7 @@ class LLMNode(Node[LLMNodeData]): _llm_file_saver: LLMFileSaver _credentials_provider: CredentialsProvider _model_factory: ModelFactory + _model_instance: ModelInstance def __init__( self, @@ -126,6 +125,7 @@ class LLMNode(Node[LLMNodeData]): *, credentials_provider: CredentialsProvider, model_factory: ModelFactory, + model_instance: ModelInstance, llm_file_saver: LLMFileSaver | None = None, ): super().__init__( @@ -139,6 +139,7 @@ class LLMNode(Node[LLMNodeData]): self._credentials_provider = credentials_provider self._model_factory = model_factory + self._model_instance = model_instance if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -202,21 +203,10 @@ class LLMNode(Node[LLMNodeData]): node_inputs["#context_files#"] = [file.model_dump() for file in context_files] # fetch model config - model_instance, model_config = self._fetch_model_config( - node_data_model=self.node_data.model, - ) - model_name = getattr(model_instance, "model_name", None) - if not isinstance(model_name, str): - model_name = model_config.model - model_provider = getattr(model_instance, "provider", None) - if not isinstance(model_provider, str): - model_provider = model_config.provider - model_schema = model_instance.model_type_instance.get_model_schema( - model_name, - model_instance.credentials, - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_name}") + model_instance = self._model_instance + model_name = model_instance.model_name + model_provider = model_instance.provider + model_stop = model_instance.stop # fetch memory memory = llm_utils.fetch_memory( @@ -240,9 +230,7 @@ class LLMNode(Node[LLMNodeData]): context=context, memory=memory, model_instance=model_instance, - model_schema=model_schema, - model_parameters=self.node_data.model.completion_params, - stop=model_config.stop, + stop=model_stop, prompt_template=self.node_data.prompt_template, memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, @@ -254,7 +242,6 @@ class LLMNode(Node[LLMNodeData]): # handle invoke result generator = LLMNode.invoke_llm( - node_data_model=self.node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -371,7 +358,6 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def invoke_llm( *, - node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, @@ -384,11 +370,10 @@ class LLMNode(Node[LLMNodeData]): node_type: NodeType, reasoning_format: Literal["separated", "tagged"] = "tagged", ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - model_schema = model_instance.model_type_instance.get_model_schema( - node_data_model.name, model_instance.credentials - ) - if not model_schema: - raise ValueError(f"Model schema not found for {node_data_model.name}") + model_parameters = model_instance.parameters + invoke_model_parameters = dict(model_parameters) + + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) if structured_output_enabled: output_schema = LLMNode.fetch_structured_output_schema( @@ -402,7 +387,7 @@ class LLMNode(Node[LLMNodeData]): 
model_instance=model_instance, prompt_messages=prompt_messages, json_schema=output_schema, - model_parameters=node_data_model.completion_params, + model_parameters=invoke_model_parameters, stop=list(stop or []), stream=True, user=user_id, @@ -412,7 +397,7 @@ class LLMNode(Node[LLMNodeData]): invoke_result = model_instance.invoke_llm( prompt_messages=list(prompt_messages), - model_parameters=node_data_model.completion_params, + model_parameters=invoke_model_parameters, stop=list(stop or []), stream=True, user=user_id, @@ -771,23 +756,6 @@ class LLMNode(Node[LLMNodeData]): return None - def _fetch_model_config( - self, - *, - node_data_model: ModelConfig, - ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model, model_config_with_cred = llm_utils.fetch_model_config( - node_data_model=node_data_model, - credentials_provider=self._credentials_provider, - model_factory=self._model_factory, - ) - completion_params = model_config_with_cred.parameters - - model_config_with_cred.parameters = completion_params - # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`. - node_data_model.completion_params = completion_params - return model, model_config_with_cred - @staticmethod def fetch_prompt_messages( *, @@ -796,8 +764,6 @@ class LLMNode(Node[LLMNodeData]): context: str | None = None, memory: TokenBufferMemory | None = None, model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -808,6 +774,7 @@ class LLMNode(Node[LLMNodeData]): context_files: list[File] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) if isinstance(prompt_template, list): # For chat model @@ -826,8 +793,6 @@ class LLMNode(Node[LLMNodeData]): memory=memory, memory_config=memory_config, model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) # Extend prompt_messages with memory messages prompt_messages.extend(memory_messages) @@ -865,8 +830,6 @@ class LLMNode(Node[LLMNodeData]): memory=memory, memory_config=memory_config, model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) # Insert histories into the prompt prompt_content = prompt_messages[0].content @@ -1316,23 +1279,23 @@ def _calculate_rest_token( *, prompt_messages: list[PromptMessage], model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], ) -> int: rest_tokens = 2000 + runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: + for parameter_rule in runtime_model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_parameters.get(parameter_rule.name) - or 
model_parameters.get(str(parameter_rule.use_template)) + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) or 0 ) @@ -1347,8 +1310,6 @@ def _handle_memory_chat_mode( memory: TokenBufferMemory | None, memory_config: MemoryConfig | None, model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], ) -> Sequence[PromptMessage]: memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model @@ -1356,8 +1317,6 @@ def _handle_memory_chat_mode( rest_tokens = _calculate_rest_token( prompt_messages=[], model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) memory_messages = memory.get_history_prompt_messages( max_token_limit=rest_tokens, @@ -1371,8 +1330,6 @@ def _handle_memory_completion_mode( memory: TokenBufferMemory | None, memory_config: MemoryConfig | None, model_instance: ModelInstance, - model_schema: AIModelEntity, - model_parameters: Mapping[str, Any], ) -> str: memory_text = "" # Get history text from memory for completion model @@ -1380,8 +1337,6 @@ def _handle_memory_completion_mode( rest_tokens = _calculate_rest_token( prompt_messages=[], model_instance=model_instance, - model_schema=model_schema, - model_parameters=model_parameters, ) if not memory_config.role_prefix: raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index f549d44efa..93402d5084 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -5,7 +5,6 @@ import uuid from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import ImagePromptMessageContent @@ -31,7 +30,7 @@ from core.workflow.file import File from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.node import Node -from core.workflow.nodes.llm import ModelConfig, llm_utils +from core.workflow.nodes.llm import llm_utils from core.workflow.runtime import VariablePool from factories.variable_factory import build_segment_with_type @@ -95,8 +94,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_type = NodeType.PARAMETER_EXTRACTOR - _model_instance: ModelInstance | None = None - _model_config: ModelConfigWithCredentialsEntity | None = None + _model_instance: ModelInstance _credentials_provider: "CredentialsProvider" _model_factory: "ModelFactory" @@ -109,6 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): *, credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", + model_instance: ModelInstance, ) -> None: super().__init__( id=id, @@ -118,6 +117,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) self._credentials_provider = credentials_provider self._model_factory = model_factory + self._model_instance = model_instance @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -155,18 +155,14 @@ class 
ParameterExtractorNode(Node[ParameterExtractorNodeData]): else [] ) - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance = self._model_instance if not isinstance(model_instance.model_type_instance, LargeLanguageModel): raise InvalidModelTypeError("Model is not a Large Language Model") - llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema( - model=model_config.model, - credentials=model_config.credentials, - ) - if not model_schema: - raise ModelSchemaNotFoundError("Model schema not found") - + try: + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + except ValueError as exc: + raise ModelSchemaNotFoundError("Model schema not found") from exc # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -184,7 +180,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data=node_data, query=query, variable_pool=self.graph_runtime_state.variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=node_data.vision.configs.detail, @@ -195,7 +191,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data=node_data, query=query, variable_pool=self.graph_runtime_state.variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=node_data.vision.configs.detail, @@ -211,24 +207,23 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): } process_data = { - "model_mode": model_config.mode, + "model_mode": node_data.model.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages + model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": None, "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), "tool_call": None, - "model_provider": model_config.provider, - "model_name": model_config.model, + "model_provider": model_instance.provider, + "model_name": model_instance.model_name, } try: text, usage, tool_call = self._invoke( - node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, tools=prompt_message_tools, - stop=model_config.stop, + stop=model_instance.stop, ) process_data["usage"] = jsonable_encoder(usage) process_data["tool_call"] = jsonable_encoder(tool_call) @@ -290,17 +285,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def _invoke( self, - node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], - stop: list[str], + stop: Sequence[str], ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, - model_parameters=node_data_model.completion_params, + model_parameters=dict(model_instance.parameters), tools=tools, - stop=stop, + stop=list(stop), stream=False, user=self.user_id, ) @@ -324,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -337,7 +331,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): ) prompt_transform = 
AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "") + rest_token = self._calculate_rest_token( + node_data=node_data, + query=query, + variable_pool=variable_pool, + model_instance=model_instance, + context="", + ) prompt_template = self._get_function_calling_prompt_template( node_data, query, variable_pool, memory, rest_token ) @@ -349,7 +349,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context="", memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, image_detail_config=vision_detail, ) @@ -406,7 +406,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -421,7 +421,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data=data, query=query, variable_pool=variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=vision_detail, @@ -431,7 +431,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data=data, query=query, variable_pool=variable_pool, - model_config=model_config, + model_instance=model_instance, memory=memory, files=files, vision_detail=vision_detail, @@ -444,7 +444,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -454,7 +454,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( - node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + node_data=node_data, + query=query, + variable_pool=variable_pool, + model_instance=model_instance, + context="", ) prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token @@ -467,7 +471,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context="", memory_config=node_data.memory, memory=memory, - model_config=model_config, + model_instance=model_instance, image_detail_config=vision_detail, ) @@ -478,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, memory: TokenBufferMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -488,7 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( - node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context="" + node_data=node_data, + query=query, + variable_pool=variable_pool, + model_instance=model_instance, + context="", ) prompt_template = 
self._get_prompt_engineering_prompt_template( node_data=node_data, @@ -508,7 +516,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context="", memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, image_detail_config=vision_detail, ) @@ -769,21 +777,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, context: str | None, ) -> int: + try: + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + except ValueError as exc: + raise ModelSchemaNotFoundError("Model schema not found") from exc prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - model_instance, model_config = self._fetch_model_config(node_data.model) - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - - llm_model = model_instance.model_type_instance - model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials) - if not model_schema: - raise ModelSchemaNotFoundError("Model schema not found") - - if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: + if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) @@ -796,27 +799,28 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): context=context, memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, ) rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) curr_message_tokens = ( - model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000 + model_type_instance.get_num_tokens( + model_instance.model_name, model_instance.credentials, prompt_messages + ) + + 1000 ) # add 1000 to ensure tool call messages max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") + model_instance.parameters.get(parameter_rule.name) + or model_instance.parameters.get(parameter_rule.use_template or "") ) or 0 rest_tokens = model_context_tokens - max_tokens - curr_message_tokens @@ -824,21 +828,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens - def _fetch_model_config( - self, node_data_model: ModelConfig - ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - """ - Fetch model config. 
- """ - if not self._model_instance or not self._model_config: - self._model_instance, self._model_config = llm_utils.fetch_model_config( - node_data_model=node_data_model, - credentials_provider=self._credentials_provider, - model_factory=self._model_factory, - ) - - return self._model_instance, self._model_config - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 3f41c0d0b7..464d9b6b9c 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -3,12 +3,10 @@ import re from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities import GraphInitParams @@ -22,7 +20,12 @@ from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils +from core.workflow.nodes.llm import ( + LLMNode, + LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, + llm_utils, +) from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from libs.json_in_md_parser import parse_and_check_json_markdown @@ -52,6 +55,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _llm_file_saver: LLMFileSaver _credentials_provider: "CredentialsProvider" _model_factory: "ModelFactory" + _model_instance: ModelInstance def __init__( self, @@ -62,6 +66,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): *, credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", + model_instance: ModelInstance, llm_file_saver: LLMFileSaver | None = None, ): super().__init__( @@ -75,6 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self._credentials_provider = credentials_provider self._model_factory = model_factory + self._model_instance = model_instance if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -95,18 +101,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None query = variable.value if variable else None variables = {"query": query} - # fetch model config - model_instance, model_config = llm_utils.fetch_model_config( - node_data_model=node_data.model, - credentials_provider=self._credentials_provider, - model_factory=self._model_factory, - ) - model_schema = model_instance.model_type_instance.get_model_schema( - 
model_instance.model_name, - model_instance.credentials, - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") + # fetch model instance + model_instance = self._model_instance # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -131,7 +127,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): rest_token = self._calculate_rest_token( node_data=node_data, query=query or "", - model_config=model_config, + model_instance=model_instance, context="", ) prompt_template = self._get_prompt_template( @@ -149,9 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): sys_query="", memory=memory, model_instance=model_instance, - model_schema=model_schema, - model_parameters=node_data.model.completion_params, - stop=model_config.stop, + stop=model_instance.stop, sys_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, @@ -166,7 +160,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): try: # handle invoke result generator = LLMNode.invoke_llm( - node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, @@ -205,14 +198,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): category_name = classes_map[category_id_result] category_id = category_id_result process_data = { - "model_mode": model_config.mode, + "model_mode": node_data.model.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages + model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), "finish_reason": finish_reason, - "model_provider": model_config.provider, - "model_name": model_config.model, + "model_provider": model_instance.provider, + "model_name": model_instance.model_name, } outputs = { "class_name": category_name, @@ -285,39 +278,40 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self, node_data: QuestionClassifierNodeData, query: str, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, context: str | None, ) -> int: - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages = prompt_transform.get_prompt( + prompt_messages, _ = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, - inputs={}, - query="", - files=[], + sys_query="", + sys_files=[], context=context, - memory_config=node_data.memory, memory=None, - model_config=model_config, + model_instance=model_instance, + stop=model_instance.stop, + memory_config=node_data.memory, + vision_enabled=False, + vision_detail=node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=[], ) rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in 
model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(parameter_rule.use_template or "")
+                    model_instance.parameters.get(parameter_rule.name)
+                    or model_instance.parameters.get(parameter_rule.use_template or "")
                ) or 0
 
        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py
index 330ebfd54a..cdecdf41d2 100644
--- a/api/tests/integration_tests/workflow/nodes/__mock/model.py
+++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py
@@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
     )
 
     return MagicMock(return_value=(model_instance, model_config))
+
+
+def get_mocked_fetch_model_instance(
+    provider: str,
+    model: str,
+    mode: str,
+    credentials: dict,
+):
+    mock_fetch_model_config = get_mocked_fetch_model_config(
+        provider=provider,
+        model=model,
+        mode=mode,
+        credentials=credentials,
+    )
+    model_instance, _ = mock_fetch_model_config()
+    return MagicMock(return_value=model_instance)
diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py
index 1b341e8f21..b5b0fb5334 100644
--- a/api/tests/integration_tests/workflow/nodes/test_llm.py
+++ b/api/tests/integration_tests/workflow/nodes/test_llm.py
@@ -5,13 +5,13 @@ from collections.abc import Generator
 from unittest.mock import MagicMock, patch
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.workflow.node_factory import DifyNodeFactory
 from core.llm_generator.output_parser.structured_output import _parse_structured_output
+from core.model_manager import ModelInstance
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
-from core.workflow.graph import Graph
 from core.workflow.node_events import StreamCompletedEvent
 from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
@@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:
 
     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
 
-    # Create node factory
-    node_factory = DifyNodeFactory(
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-    )
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
-
     node = LLMNode(
         id=str(uuid.uuid4()),
         config=config,
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
-        credentials_provider=MagicMock(),
-        model_factory=MagicMock(),
+        credentials_provider=MagicMock(spec=CredentialsProvider),
+        model_factory=MagicMock(spec=ModelFactory),
+        model_instance=MagicMock(spec=ModelInstance),
     )
 
     return node
@@ -116,8 +109,7 @@ def test_execute_llm():
 
     db.session.close = MagicMock()
 
-    # Mock the _fetch_model_config to avoid database calls
-    def mock_fetch_model_config(*_args, **_kwargs):
+    def build_mock_model_instance() -> MagicMock:
         from decimal import Decimal
         from unittest.mock import MagicMock
 
@@ -125,7 +117,20 @@ def test_execute_llm():
         from core.model_runtime.entities.message_entities import AssistantPromptMessage
 
         # Create mock model instance
-        mock_model_instance = MagicMock()
+        mock_model_instance = MagicMock(spec=ModelInstance)
+        mock_model_instance.provider = "openai"
+        mock_model_instance.model_name = "gpt-3.5-turbo"
+        mock_model_instance.credentials = {}
+        mock_model_instance.parameters = {}
+        mock_model_instance.stop = []
+        mock_model_instance.model_type_instance = MagicMock()
+        mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
+            model_properties={},
+            parameter_rules=[],
+            features=[],
+        )
+        mock_model_instance.provider_model_bundle = MagicMock()
+        mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
         mock_usage = LLMUsage(
             prompt_tokens=30,
             prompt_unit_price=Decimal("0.001"),
@@ -149,14 +154,7 @@ def test_execute_llm():
         )
         mock_model_instance.invoke_llm.return_value = mock_llm_result
 
-        # Create mock model config
-        mock_model_config = MagicMock()
-        mock_model_config.mode = "chat"
-        mock_model_config.provider = "openai"
-        mock_model_config.model = "gpt-3.5-turbo"
-        mock_model_config.parameters = {}
-
-        return mock_model_instance, mock_model_config
+        return mock_model_instance
 
     # Mock fetch_prompt_messages to avoid database calls
     def mock_fetch_prompt_messages_1(**_kwargs):
@@ -167,10 +165,9 @@ def test_execute_llm():
             UserPromptMessage(content="what's the weather today?"),
         ], []
 
-    with (
-        patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
-        patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
-    ):
+    node._model_instance = build_mock_model_instance()
+
+    with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
         # execute node
         result = node._run()
         assert isinstance(result, Generator)
@@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
     # Mock db.session.close()
     db.session.close = MagicMock()
 
-    # Mock the _fetch_model_config method
-    def mock_fetch_model_config(*_args, **_kwargs):
+    def build_mock_model_instance() -> MagicMock:
         from decimal import Decimal
         from unittest.mock import MagicMock
 
@@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
         from core.model_runtime.entities.message_entities import AssistantPromptMessage
 
         # Create mock model instance
-        mock_model_instance = MagicMock()
+        mock_model_instance = MagicMock(spec=ModelInstance)
+        mock_model_instance.provider = "openai"
+        mock_model_instance.model_name = "gpt-3.5-turbo"
+        mock_model_instance.credentials = {}
+        mock_model_instance.parameters = {}
+        mock_model_instance.stop = []
+        mock_model_instance.model_type_instance = MagicMock()
+        mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
+            model_properties={},
+            parameter_rules=[],
+            features=[],
+        )
+        mock_model_instance.provider_model_bundle = MagicMock()
+        mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
         mock_usage = LLMUsage(
             prompt_tokens=30,
             prompt_unit_price=Decimal("0.001"),
@@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
         )
         mock_model_instance.invoke_llm.return_value = mock_llm_result
 
-        # Create mock model config
-        mock_model_config = MagicMock()
-        mock_model_config.mode = "chat"
-        mock_model_config.provider = "openai"
-        mock_model_config.model = "gpt-3.5-turbo"
-        mock_model_config.parameters = {}
-
-        return mock_model_instance, mock_model_config
+        return mock_model_instance
 
     # Mock fetch_prompt_messages to avoid database calls
     def mock_fetch_prompt_messages_2(**_kwargs):
@@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
             UserPromptMessage(content="what's the weather today?"),
         ], []
 
-    with (
-        patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
-        patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
-    ):
+    node._model_instance = build_mock_model_instance()
+
+    with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
         # execute node
         result = node._run()
 
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index 88edc4f9b3..e791f12393 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -4,18 +4,17 @@ import uuid
 from unittest.mock import MagicMock
 
 from core.app.entities.app_invoke_entities import InvokeFrom
-from core.app.workflow.node_factory import DifyNodeFactory
+from core.model_manager import ModelInstance
 from core.model_runtime.entities import AssistantPromptMessage
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import WorkflowNodeExecutionStatus
-from core.workflow.graph import Graph
 from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from extensions.ext_database import db
 from models.enums import UserFrom
-from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
+from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance
 
 """FOR MOCK FIXTURES, DO NOT REMOVE"""
 from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict):
 
     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
 
-    # Create node factory
-    node_factory = DifyNodeFactory(
-        graph_init_params=init_params,
-        graph_runtime_state=graph_runtime_state,
-    )
-
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
-
     node = ParameterExtractorNode(
         id=str(uuid.uuid4()),
         config=config,
@@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
         graph_runtime_state=graph_runtime_state,
         credentials_provider=MagicMock(spec=CredentialsProvider),
         model_factory=MagicMock(spec=ModelFactory),
+        model_instance=MagicMock(spec=ModelInstance),
     )
 
     return node
@@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
         }
     )
 
-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
         provider="langgenius/openai/openai",
         model="gpt-3.5-turbo",
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
 
     db.session.close = MagicMock()
     result = node._run()
@@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
         },
     )
 
-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
         provider="langgenius/openai/openai",
         model="gpt-3.5-turbo",
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
 
     db.session.close = MagicMock()
     result = node._run()
@@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
         },
     )
 
-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
         provider="langgenius/openai/openai",
         model="gpt-3.5-turbo",
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
 
     db.session.close = MagicMock()
     result = node._run()
@@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
         },
     )
 
-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
         provider="langgenius/openai/openai",
         model="gpt-3.5-turbo-instruct",
         mode="completion",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
 
     db.session.close = MagicMock()
     result = node._run()
@@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
         },
     )
 
-    node._fetch_model_config = get_mocked_fetch_model_config(
+    node._model_instance = get_mocked_fetch_model_instance(
         provider="langgenius/openai/openai",
         model="gpt-3.5-turbo",
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
-    )
+    )()
     # Test the mock before running the actual test
     monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
     db.session.close = MagicMock()
diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py
index cb691d5c3d..eb85fc21ca 100644
--- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py
+++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py
@@ -1391,10 +1391,20 @@ class TestWorkflowService:
 
         workflow_service = WorkflowService()
 
+        from unittest.mock import patch
+
+        from core.app.workflow.node_factory import DifyNodeFactory
+        from core.model_manager import ModelInstance
+
         # Act
-        result = workflow_service.run_free_workflow_node(
-            node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
-        )
+        with patch.object(
+            DifyNodeFactory,
+            "_build_model_instance_for_llm_node",
+            return_value=MagicMock(spec=ModelInstance),
+        ):
+            result = workflow_service.run_free_workflow_node(
+                node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
+            )
 
         # Assert
         assert result is not None
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 71e8a9d863..5aed463a45 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -10,6 +10,7 @@ from collections.abc import Generator, Mapping
 from typing import TYPE_CHECKING, Any, Optional
 from unittest.mock import MagicMock
 
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
 from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@@ -44,9 +45,10 @@ class MockNodeMixin:
         mock_config: Optional["MockConfig"] = None,
         **kwargs: Any,
     ):
-        if isinstance(self, (LLMNode, QuestionClassifierNode)):
+        if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
             kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
             kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
+            kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
 
         super().__init__(
             id=id,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
index 53c6bc3d60..d7becaaded 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py
@@ -9,11 +9,12 @@ This test validates that:
 """
 
 import time
-from unittest.mock import patch
+from unittest.mock import MagicMock, patch
 from uuid import uuid4
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.workflow.node_factory import DifyNodeFactory
+from core.model_manager import ModelInstance
 from core.workflow.entities import GraphInitParams
 from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
 from core.workflow.graph import Graph
@@ -115,7 +116,12 @@ def test_parallel_streaming_workflow():
 
     # Create node factory and graph
     node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
-    graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
+    with patch.object(
+        DifyNodeFactory,
+        "_build_model_instance_for_llm_node",
+        return_value=MagicMock(spec=ModelInstance),
+    ):
+        graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
 
     # Create the graph engine
     engine = GraphEngine(
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
index afa9265fcd..5c85f2be92 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py
@@ -547,8 +547,22 @@ class TableTestRunner:
         """Run tests in parallel."""
         results = []
 
+        flask_app: Any = None
+        try:
+            from flask import current_app
+
+            flask_app = current_app._get_current_object()  # type: ignore[attr-defined]
+        except RuntimeError:
+            flask_app = None
+
+        def _run_test_case_with_context(test_case: WorkflowTestCase) -> WorkflowTestResult:
+            if flask_app is None:
+                return self.run_test_case(test_case)
+            with flask_app.app_context():
+                return self.run_test_case(test_case)
+
         with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
-            future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases}
+            future_to_test = {executor.submit(_run_test_case_with_context, tc): tc for tc in test_cases}
 
             for future in as_completed(future_to_test):
                 test_case = future_to_test[future]
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index ebabf66b41..a235a4167c 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
 from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
+from core.model_manager import ModelInstance
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.message_entities import (
     ImagePromptMessageContent,
@@ -115,6 +116,7 @@ def llm_node(
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
         model_factory=mock_model_factory,
+        model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
     )
     return node
@@ -601,6 +603,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
         graph_runtime_state=graph_runtime_state,
         credentials_provider=mock_credentials_provider,
         model_factory=mock_model_factory,
+        model_instance=mock.MagicMock(spec=ModelInstance),
         llm_file_saver=mock_file_saver,
     )
     return node, mock_file_saver