Mirror of https://github.com/langgenius/dify.git (synced 2026-03-10 11:10:19 +08:00)

refactor: consolidate LLM runtime model state on ModelInstance (#32746)

Signed-off-by: -LAN- <laipz8200@outlook.com>

This commit is contained in:
parent 48d8667c4f
commit 962df17a15
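What the change amounts to: the node factory now builds one ModelInstance per LLM-style node and hangs all per-invocation state on it (credentials, completion parameters, stop sequences), so LLMNode, QuestionClassifierNode and ParameterExtractorNode stop re-deriving that state from a ModelConfigWithCredentialsEntity at run time. A minimal standalone sketch of the resulting data flow; the classes and sample values below are simplified stand-ins, not the real Dify types:

from collections.abc import Mapping, Sequence
from typing import Any


class ModelInstance:
    """Stand-in for core.model_manager.ModelInstance with the new runtime fields."""

    def __init__(self, provider: str, model_name: str) -> None:
        self.provider = provider
        self.model_name = model_name
        self.credentials: Mapping[str, Any] = {}
        # Runtime LLM invocation fields, consolidated here by this commit.
        self.parameters: Mapping[str, Any] = {}
        self.stop: Sequence[str] = ()


def build_llm_model_instance(node_model: Mapping[str, Any], credentials: Mapping[str, Any]) -> ModelInstance:
    """Roughly what DifyNodeFactory._build_model_instance_for_llm_node does."""
    instance = ModelInstance(node_model["provider"], node_model["name"])
    params = dict(node_model.get("completion_params", {}))
    stop = params.pop("stop", [])
    instance.credentials = credentials
    instance.parameters = params
    instance.stop = tuple(stop) if isinstance(stop, list) else ()
    return instance


# The node later reads runtime state straight from the instance.
instance = build_llm_model_instance(
    {"provider": "openai", "name": "gpt-4o-mini", "completion_params": {"temperature": 0.2, "stop": ["\n\n"]}},
    credentials={"api_key": "sk-..."},
)
assert instance.parameters == {"temperature": 0.2}
assert instance.stop == ("\n\n",)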
@@ -110,7 +110,6 @@ ignore_imports =
    core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
    core.workflow.nodes.llm.llm_utils -> configs
    core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
    core.workflow.nodes.llm.llm_utils -> core.model_manager
    core.workflow.nodes.llm.protocols -> core.model_manager
    core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
@@ -129,13 +128,9 @@ ignore_imports =
    core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
    core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model
    core.workflow.nodes.question_classifier.question_classifier_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.advanced_prompt_transform
    core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform
    core.workflow.nodes.start.entities -> core.app.app_config.entities
    core.workflow.nodes.start.start_node -> core.app.app_config.entities
@@ -83,14 +83,21 @@ def fetch_model_config(
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()

    stop: list[str] = []
    if "stop" in node_data_model.completion_params:
        stop = node_data_model.completion_params.pop("stop")
    completion_params = dict(node_data_model.completion_params)
    stop = completion_params.pop("stop", [])
    if not isinstance(stop, list):
        stop = []

    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
    if not model_schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

    model_instance.provider = node_data_model.provider
    model_instance.model_name = node_data_model.name
    model_instance.credentials = credentials
    model_instance.parameters = completion_params
    model_instance.stop = tuple(stop)

    return model_instance, ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=node_data_model.name,
@@ -98,6 +105,6 @@ def fetch_model_config(
        mode=node_data_model.mode,
        provider_model_bundle=provider_model_bundle,
        credentials=credentials,
        parameters=node_data_model.completion_params,
        parameters=completion_params,
        stop=stop,
    )
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, final
from typing import TYPE_CHECKING, Any, cast, final

from typing_extensions import override

@@ -9,6 +9,9 @@ from core.datasource.datasource_manager import DatasourceManager
from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
@@ -23,6 +26,8 @@ from core.workflow.nodes.datasource import DatasourceNode
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
@@ -171,6 +176,7 @@ class DifyNodeFactory(NodeFactory):
            )

        if node_type == NodeType.LLM:
            model_instance = self._build_model_instance_for_llm_node(node_data)
            return LLMNode(
                id=node_id,
                config=node_config,
@@ -178,6 +184,7 @@ class DifyNodeFactory(NodeFactory):
                graph_runtime_state=self.graph_runtime_state,
                credentials_provider=self._llm_credentials_provider,
                model_factory=self._llm_model_factory,
                model_instance=model_instance,
            )

        if node_type == NodeType.DATASOURCE:
@@ -208,6 +215,7 @@ class DifyNodeFactory(NodeFactory):
            )

        if node_type == NodeType.QUESTION_CLASSIFIER:
            model_instance = self._build_model_instance_for_llm_node(node_data)
            return QuestionClassifierNode(
                id=node_id,
                config=node_config,
@@ -215,9 +223,11 @@ class DifyNodeFactory(NodeFactory):
                graph_runtime_state=self.graph_runtime_state,
                credentials_provider=self._llm_credentials_provider,
                model_factory=self._llm_model_factory,
                model_instance=model_instance,
            )

        if node_type == NodeType.PARAMETER_EXTRACTOR:
            model_instance = self._build_model_instance_for_llm_node(node_data)
            return ParameterExtractorNode(
                id=node_id,
                config=node_config,
@@ -225,6 +235,7 @@ class DifyNodeFactory(NodeFactory):
                graph_runtime_state=self.graph_runtime_state,
                credentials_provider=self._llm_credentials_provider,
                model_factory=self._llm_model_factory,
                model_instance=model_instance,
            )

        return node_class(
@@ -233,3 +244,37 @@ class DifyNodeFactory(NodeFactory):
            graph_init_params=self.graph_init_params,
            graph_runtime_state=self.graph_runtime_state,
        )

    def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance:
        node_data_model = ModelConfig.model_validate(node_data["model"])
        if not node_data_model.mode:
            raise LLMModeRequiredError("LLM mode is required.")

        credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name)
        model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
        provider_model_bundle = model_instance.provider_model_bundle

        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=node_data_model.name,
            model_type=ModelType.LLM,
        )
        if provider_model is None:
            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
        provider_model.raise_for_status()

        completion_params = dict(node_data_model.completion_params)
        stop = completion_params.pop("stop", [])
        if not isinstance(stop, list):
            stop = []

        model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
        if not model_schema:
            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

        model_instance.provider = node_data_model.provider
        model_instance.model_name = node_data_model.name
        model_instance.credentials = credentials
        model_instance.parameters = completion_params
        model_instance.stop = tuple(stop)
        model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
        return model_instance
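The helper above normalizes completion parameters the same way fetch_model_config used to: copy the configured dict so the node data is left untouched, pull the stop sequences out, and drop anything that is not a list. A self-contained sketch of just that step (the function name and sample values are illustrative, not part of the diff):

from collections.abc import Mapping
from typing import Any


def split_completion_params(raw: Mapping[str, Any]) -> tuple[dict[str, Any], tuple[str, ...]]:
    """Separate stop sequences from the remaining completion parameters."""
    params = dict(raw)              # copy so the node's configured params stay untouched
    stop = params.pop("stop", [])   # stop sequences travel inside completion_params
    if not isinstance(stop, list):  # guard against malformed config
        stop = []
    return params, tuple(stop)


params, stop = split_completion_params({"temperature": 0.7, "max_tokens": 256, "stop": ["Human:"]})
assert params == {"temperature": 0.7, "max_tokens": 256}
assert stop == ("Human:",)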
@@ -1,5 +1,5 @@
import logging
from collections.abc import Callable, Generator, Iterable, Sequence
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
from typing import IO, Any, Literal, Optional, Union, cast, overload

from configs import dify_config
@@ -38,6 +38,9 @@ class ModelInstance:
        self.model_name = model
        self.provider = provider_model_bundle.configuration.provider.provider
        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
        # Runtime LLM invocation fields.
        self.parameters: Mapping[str, Any] = {}
        self.stop: Sequence[str] = ()
        self.model_type_instance = self.provider_model_bundle.model_type_instance
        self.load_balancing_manager = self._get_load_balancing_manager(
            configuration=provider_model_bundle.configuration,
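ModelInstance now starts with empty, read-only runtime state (an empty Mapping for parameters, an empty tuple for stop), and the invocation sites later in this diff copy them into mutable shapes with dict(...) and list(...) right before calling the runtime. A rough illustration of that contract, using a stand-in class rather than the real constructor (which takes a provider bundle):

from collections.abc import Mapping, Sequence
from typing import Any


class RuntimeState:
    """Stand-in for the runtime fields ModelInstance.__init__ now initializes."""

    def __init__(self) -> None:
        # Runtime LLM invocation fields.
        self.parameters: Mapping[str, Any] = {}  # empty until a factory fills it in
        self.stop: Sequence[str] = ()            # tuple, so it cannot be mutated in place


state = RuntimeState()
# Invocation sites convert to the mutable shapes the model runtime API expects.
invoke_model_parameters = dict(state.parameters)
stop = list(state.stop)
assert invoke_model_parameters == {} and stop == []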
@@ -4,6 +4,7 @@ from typing import cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
@@ -44,7 +45,8 @@ class AdvancedPromptTransform(PromptTransform):
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
        prompt_messages = []
@@ -59,6 +61,7 @@ class AdvancedPromptTransform(PromptTransform):
                memory_config=memory_config,
                memory=memory,
                model_config=model_config,
                model_instance=model_instance,
                image_detail_config=image_detail_config,
            )
        elif isinstance(prompt_template, list) and all(isinstance(item, ChatModelMessage) for item in prompt_template):
@@ -71,6 +74,7 @@ class AdvancedPromptTransform(PromptTransform):
                memory_config=memory_config,
                memory=memory,
                model_config=model_config,
                model_instance=model_instance,
                image_detail_config=image_detail_config,
            )

@@ -85,7 +89,8 @@ class AdvancedPromptTransform(PromptTransform):
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
        """
@@ -111,6 +116,7 @@ class AdvancedPromptTransform(PromptTransform):
                parser=parser,
                prompt_inputs=prompt_inputs,
                model_config=model_config,
                model_instance=model_instance,
            )

        if query:
@@ -146,7 +152,8 @@ class AdvancedPromptTransform(PromptTransform):
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
        """
@@ -198,8 +205,13 @@ class AdvancedPromptTransform(PromptTransform):

        prompt_message_contents: list[PromptMessageContentUnionTypes] = []
        if memory and memory_config:
            prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)

            prompt_messages = self._append_chat_histories(
                memory,
                memory_config,
                prompt_messages,
                model_config=model_config,
                model_instance=model_instance,
            )
        if files and query is not None:
            for file in files:
                prompt_message_contents.append(
@@ -276,7 +288,8 @@ class AdvancedPromptTransform(PromptTransform):
        role_prefix: MemoryConfig.RolePrefix,
        parser: PromptTemplateParser,
        prompt_inputs: Mapping[str, str],
        model_config: ModelConfigWithCredentialsEntity,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
    ) -> Mapping[str, str]:
        prompt_inputs = dict(prompt_inputs)
        if "#histories#" in parser.variable_keys:
@@ -286,7 +299,11 @@ class AdvancedPromptTransform(PromptTransform):
            prompt_inputs = {k: inputs[k] for k in parser.variable_keys if k in inputs}
            tmp_human_message = UserPromptMessage(content=parser.format(prompt_inputs))

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            rest_tokens = self._calculate_rest_token(
                [tmp_human_message],
                model_config=model_config,
                model_instance=model_instance,
            )

            histories = self._get_history_messages_from_memory(
                memory=memory,
@@ -41,7 +41,7 @@ class AgentHistoryPromptTransform(PromptTransform):
        if not self.memory:
            return prompt_messages

        max_token_limit = self._calculate_rest_token(self.prompt_messages, self.model_config)
        max_token_limit = self._calculate_rest_token(self.prompt_messages, model_config=self.model_config)

        model_type_instance = self.model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -4,45 +4,83 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey
from core.prompt.entities.advanced_prompt_entities import MemoryConfig


class PromptTransform:
    def _resolve_model_runtime(
        self,
        *,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
    ) -> tuple[ModelInstance, AIModelEntity]:
        if model_instance is None:
            if model_config is None:
                raise ValueError("Either model_config or model_instance must be provided.")
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )
            model_instance.credentials = model_config.credentials
            model_instance.parameters = model_config.parameters
            model_instance.stop = model_config.stop

        model_schema = model_instance.model_type_instance.get_model_schema(
            model=model_instance.model_name,
            credentials=model_instance.credentials,
        )
        if model_schema is None:
            if model_config is None:
                raise ValueError("Model schema not found for the provided model instance.")
            model_schema = model_config.model_schema

        return model_instance, model_schema

    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
        *,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
    ) -> list[PromptMessage]:
        rest_tokens = self._calculate_rest_token(prompt_messages, model_config)
        rest_tokens = self._calculate_rest_token(
            prompt_messages,
            model_config=model_config,
            model_instance=model_instance,
        )
        histories = self._get_history_messages_list_from_memory(memory, memory_config, rest_tokens)
        prompt_messages.extend(histories)

        return prompt_messages

    def _calculate_rest_token(
        self, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
        self,
        prompt_messages: list[PromptMessage],
        *,
        model_config: ModelConfigWithCredentialsEntity | None = None,
        model_instance: ModelInstance | None = None,
    ) -> int:
        model_instance, model_schema = self._resolve_model_runtime(
            model_config=model_config,
            model_instance=model_instance,
        )
        model_parameters = model_instance.parameters
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_instance = ModelInstance(
                provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
            )

            curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
            for parameter_rule in model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                        model_parameters.get(parameter_rule.name)
                        or model_parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
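_resolve_model_runtime is the compatibility shim that lets the prompt transforms accept either the legacy ModelConfigWithCredentialsEntity or a pre-built ModelInstance. A standalone sketch of the same fallback rules, using small dataclass stand-ins instead of the Dify types (all names and values here are illustrative):

from dataclasses import dataclass, field
from typing import Any, Optional


@dataclass
class Instance:
    model_name: str
    credentials: dict[str, Any] = field(default_factory=dict)
    parameters: dict[str, Any] = field(default_factory=dict)
    stop: tuple[str, ...] = ()
    schema: Optional[dict[str, Any]] = None  # what get_model_schema() would return


@dataclass
class Config:
    model: str
    credentials: dict[str, Any]
    parameters: dict[str, Any]
    stop: list[str]
    model_schema: dict[str, Any]


def resolve_model_runtime(
    *, model_config: Optional[Config] = None, model_instance: Optional[Instance] = None
) -> tuple[Instance, dict[str, Any]]:
    if model_instance is None:
        if model_config is None:
            raise ValueError("Either model_config or model_instance must be provided.")
        # Legacy path: build an instance from the config entity and copy its runtime state over.
        model_instance = Instance(model_name=model_config.model)
        model_instance.credentials = model_config.credentials
        model_instance.parameters = model_config.parameters
        model_instance.stop = tuple(model_config.stop)
        model_instance.schema = model_config.model_schema
    schema = model_instance.schema
    if schema is None:
        if model_config is None:
            raise ValueError("Model schema not found for the provided model instance.")
        schema = model_config.model_schema
    return model_instance, schema


# New callers pass the instance; old callers keep passing the config entity unchanged.
inst, schema = resolve_model_runtime(model_instance=Instance("gpt-4o-mini", schema={"context_size": 16384}))
assert schema == {"context_size": 16384}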
@@ -252,7 +252,7 @@ class SimplePromptTransform(PromptTransform):
        if memory:
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config=model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
@@ -5,20 +5,16 @@ from sqlalchemy import select, update
from sqlalchemy.orm import Session

from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -29,46 +25,14 @@ from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError


def fetch_model_config(
    *,
    node_data_model: ModelConfig,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
    provider_model_bundle = model_instance.provider_model_bundle

    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=node_data_model.name,
        model_type=ModelType.LLM,
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
    model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
        model_instance.model_name,
        model_instance.credentials,
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()

    stop: list[str] = []
    if "stop" in node_data_model.completion_params:
        stop = node_data_model.completion_params.pop("stop")

    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
    if not model_schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
    return model_instance, ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=node_data_model.name,
        model_schema=model_schema,
        mode=node_data_model.mode,
        provider_model_bundle=provider_model_bundle,
        credentials=credentials,
        parameters=node_data_model.completion_params,
        stop=stop,
    )
        raise ValueError(f"Model schema not found for {model_instance.model_name}")
    return model_schema


def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
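fetch_model_config is gone from llm_utils; callers that only need the schema now call fetch_model_schema(model_instance=...), which raises ValueError when the runtime has no schema for the configured model. Nodes that want their own exception wrap it, as parameter_extractor_node does later in this diff. A small runnable sketch of that wrapping pattern, with a stub standing in for the real helper:

from typing import Any


class ModelSchemaNotFoundError(Exception):
    """Node-level error for a missing model schema (same name as in the diff)."""


def fetch_model_schema(*, model_instance: Any) -> dict[str, Any]:
    """Stub for core.workflow.nodes.llm.llm_utils.fetch_model_schema."""
    schema = getattr(model_instance, "schema", None)
    if schema is None:
        raise ValueError(f"Model schema not found for {model_instance.model_name}")
    return schema


def fetch_schema_for_node(model_instance: Any) -> dict[str, Any]:
    # Nodes translate the generic ValueError into their own domain error,
    # matching the try/except added in parameter_extractor_node below.
    try:
        return fetch_model_schema(model_instance=model_instance)
    except ValueError as exc:
        raise ModelSchemaNotFoundError("Model schema not found") from exc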
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import select

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
@@ -38,7 +37,7 @@ from core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -83,7 +82,6 @@ from .entities import (
    LLMNodeChatModelMessage,
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
)
from .exc import (
    InvalidContextStructureError,
@@ -116,6 +114,7 @@ class LLMNode(Node[LLMNodeData]):
    _llm_file_saver: LLMFileSaver
    _credentials_provider: CredentialsProvider
    _model_factory: ModelFactory
    _model_instance: ModelInstance

    def __init__(
        self,
@@ -126,6 +125,7 @@ class LLMNode(Node[LLMNodeData]):
        *,
        credentials_provider: CredentialsProvider,
        model_factory: ModelFactory,
        model_instance: ModelInstance,
        llm_file_saver: LLMFileSaver | None = None,
    ):
        super().__init__(
@@ -139,6 +139,7 @@ class LLMNode(Node[LLMNodeData]):

        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        self._model_instance = model_instance

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
@@ -202,21 +203,10 @@ class LLMNode(Node[LLMNodeData]):
            node_inputs["#context_files#"] = [file.model_dump() for file in context_files]

        # fetch model config
        model_instance, model_config = self._fetch_model_config(
            node_data_model=self.node_data.model,
        )
        model_name = getattr(model_instance, "model_name", None)
        if not isinstance(model_name, str):
            model_name = model_config.model
        model_provider = getattr(model_instance, "provider", None)
        if not isinstance(model_provider, str):
            model_provider = model_config.provider
        model_schema = model_instance.model_type_instance.get_model_schema(
            model_name,
            model_instance.credentials,
        )
        if not model_schema:
            raise ValueError(f"Model schema not found for {model_name}")
        model_instance = self._model_instance
        model_name = model_instance.model_name
        model_provider = model_instance.provider
        model_stop = model_instance.stop

        # fetch memory
        memory = llm_utils.fetch_memory(
@@ -240,9 +230,7 @@ class LLMNode(Node[LLMNodeData]):
            context=context,
            memory=memory,
            model_instance=model_instance,
            model_schema=model_schema,
            model_parameters=self.node_data.model.completion_params,
            stop=model_config.stop,
            stop=model_stop,
            prompt_template=self.node_data.prompt_template,
            memory_config=self.node_data.memory,
            vision_enabled=self.node_data.vision.enabled,
@@ -254,7 +242,6 @@ class LLMNode(Node[LLMNodeData]):

        # handle invoke result
        generator = LLMNode.invoke_llm(
            node_data_model=self.node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            stop=stop,
@@ -371,7 +358,6 @@ class LLMNode(Node[LLMNodeData]):
    @staticmethod
    def invoke_llm(
        *,
        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None = None,
@@ -384,11 +370,10 @@ class LLMNode(Node[LLMNodeData]):
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        model_schema = model_instance.model_type_instance.get_model_schema(
            node_data_model.name, model_instance.credentials
        )
        if not model_schema:
            raise ValueError(f"Model schema not found for {node_data_model.name}")
        model_parameters = model_instance.parameters
        invoke_model_parameters = dict(model_parameters)

        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if structured_output_enabled:
            output_schema = LLMNode.fetch_structured_output_schema(
@@ -402,7 +387,7 @@ class LLMNode(Node[LLMNodeData]):
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                json_schema=output_schema,
                model_parameters=node_data_model.completion_params,
                model_parameters=invoke_model_parameters,
                stop=list(stop or []),
                stream=True,
                user=user_id,
@@ -412,7 +397,7 @@ class LLMNode(Node[LLMNodeData]):

        invoke_result = model_instance.invoke_llm(
            prompt_messages=list(prompt_messages),
            model_parameters=node_data_model.completion_params,
            model_parameters=invoke_model_parameters,
            stop=list(stop or []),
            stream=True,
            user=user_id,
@@ -771,23 +756,6 @@ class LLMNode(Node[LLMNodeData]):

        return None

    def _fetch_model_config(
        self,
        *,
        node_data_model: ModelConfig,
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        model, model_config_with_cred = llm_utils.fetch_model_config(
            node_data_model=node_data_model,
            credentials_provider=self._credentials_provider,
            model_factory=self._model_factory,
        )
        completion_params = model_config_with_cred.parameters

        model_config_with_cred.parameters = completion_params
        # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
        node_data_model.completion_params = completion_params
        return model, model_config_with_cred

    @staticmethod
    def fetch_prompt_messages(
        *,
@@ -796,8 +764,6 @@ class LLMNode(Node[LLMNodeData]):
        context: str | None = None,
        memory: TokenBufferMemory | None = None,
        model_instance: ModelInstance,
        model_schema: AIModelEntity,
        model_parameters: Mapping[str, Any],
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        stop: Sequence[str] | None = None,
        memory_config: MemoryConfig | None = None,
@@ -808,6 +774,7 @@ class LLMNode(Node[LLMNodeData]):
        context_files: list[File] | None = None,
    ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
        prompt_messages: list[PromptMessage] = []
        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

        if isinstance(prompt_template, list):
            # For chat model
@@ -826,8 +793,6 @@ class LLMNode(Node[LLMNodeData]):
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
                model_schema=model_schema,
                model_parameters=model_parameters,
            )
            # Extend prompt_messages with memory messages
            prompt_messages.extend(memory_messages)
@@ -865,8 +830,6 @@ class LLMNode(Node[LLMNodeData]):
                memory=memory,
                memory_config=memory_config,
                model_instance=model_instance,
                model_schema=model_schema,
                model_parameters=model_parameters,
            )
            # Insert histories into the prompt
            prompt_content = prompt_messages[0].content
@@ -1316,23 +1279,23 @@ def _calculate_rest_token(
    *,
    prompt_messages: list[PromptMessage],
    model_instance: ModelInstance,
    model_schema: AIModelEntity,
    model_parameters: Mapping[str, Any],
) -> int:
    rest_tokens = 2000
    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
    runtime_model_parameters = model_instance.parameters

    model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if model_context_tokens:
        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        max_tokens = 0
        for parameter_rule in model_schema.parameter_rules:
        for parameter_rule in runtime_model_schema.parameter_rules:
            if parameter_rule.name == "max_tokens" or (
                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
            ):
                max_tokens = (
                    model_parameters.get(parameter_rule.name)
                    or model_parameters.get(str(parameter_rule.use_template))
                    runtime_model_parameters.get(parameter_rule.name)
                    or runtime_model_parameters.get(str(parameter_rule.use_template))
                    or 0
                )
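The rewritten _calculate_rest_token keeps the same budget arithmetic, only sourcing the schema and parameters from the ModelInstance: the room left for history is the model's context size minus the max_tokens reserved for output minus the tokens already used by the current prompt. A worked example of that arithmetic with made-up numbers:

def rest_tokens(context_size: int, parameters: dict[str, int], curr_message_tokens: int) -> int:
    # "max_tokens" comes from the runtime parameters carried on the ModelInstance.
    max_tokens = parameters.get("max_tokens", 0) or 0
    return context_size - max_tokens - curr_message_tokens


# A 16,384-token context with 1,024 tokens reserved for output and a 2,000-token prompt
# leaves 13,360 tokens for chat history.
assert rest_tokens(16384, {"max_tokens": 1024}, 2000) == 13360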
@@ -1347,8 +1310,6 @@ def _handle_memory_chat_mode(
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
    model_schema: AIModelEntity,
    model_parameters: Mapping[str, Any],
) -> Sequence[PromptMessage]:
    memory_messages: Sequence[PromptMessage] = []
    # Get messages from memory for chat model
@@ -1356,8 +1317,6 @@ def _handle_memory_chat_mode(
        rest_tokens = _calculate_rest_token(
            prompt_messages=[],
            model_instance=model_instance,
            model_schema=model_schema,
            model_parameters=model_parameters,
        )
        memory_messages = memory.get_history_prompt_messages(
            max_token_limit=rest_tokens,
@@ -1371,8 +1330,6 @@ def _handle_memory_completion_mode(
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_instance: ModelInstance,
    model_schema: AIModelEntity,
    model_parameters: Mapping[str, Any],
) -> str:
    memory_text = ""
    # Get history text from memory for completion model
@@ -1380,8 +1337,6 @@ def _handle_memory_completion_mode(
        rest_tokens = _calculate_rest_token(
            prompt_messages=[],
            model_instance=model_instance,
            model_schema=model_schema,
            model_parameters=model_parameters,
        )
        if not memory_config.role_prefix:
            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
@@ -5,7 +5,6 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
@@ -31,7 +30,7 @@ from core.workflow.file import File
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import variable_template_parser
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.nodes.llm import llm_utils
from core.workflow.runtime import VariablePool
from factories.variable_factory import build_segment_with_type

@@ -95,8 +94,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

    node_type = NodeType.PARAMETER_EXTRACTOR

    _model_instance: ModelInstance | None = None
    _model_config: ModelConfigWithCredentialsEntity | None = None
    _model_instance: ModelInstance
    _credentials_provider: "CredentialsProvider"
    _model_factory: "ModelFactory"

@@ -109,6 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        *,
        credentials_provider: "CredentialsProvider",
        model_factory: "ModelFactory",
        model_instance: ModelInstance,
    ) -> None:
        super().__init__(
            id=id,
@@ -118,6 +117,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        )
        self._credentials_provider = credentials_provider
        self._model_factory = model_factory
        self._model_instance = model_instance

    @classmethod
    def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -155,18 +155,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            else []
        )

        model_instance, model_config = self._fetch_model_config(node_data.model)
        model_instance = self._model_instance
        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
            raise InvalidModelTypeError("Model is not a Large Language Model")

        llm_model = model_instance.model_type_instance
        model_schema = llm_model.get_model_schema(
            model=model_config.model,
            credentials=model_config.credentials,
        )
        if not model_schema:
            raise ModelSchemaNotFoundError("Model schema not found")

        try:
            model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
        except ValueError as exc:
            raise ModelSchemaNotFoundError("Model schema not found") from exc
        # fetch memory
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
@@ -184,7 +180,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                node_data=node_data,
                query=query,
                variable_pool=self.graph_runtime_state.variable_pool,
                model_config=model_config,
                model_instance=model_instance,
                memory=memory,
                files=files,
                vision_detail=node_data.vision.configs.detail,
@@ -195,7 +191,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                data=node_data,
                query=query,
                variable_pool=self.graph_runtime_state.variable_pool,
                model_config=model_config,
                model_instance=model_instance,
                memory=memory,
                files=files,
                vision_detail=node_data.vision.configs.detail,
@@ -211,24 +207,23 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        }

        process_data = {
            "model_mode": model_config.mode,
            "model_mode": node_data.model.mode,
            "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
                model_mode=model_config.mode, prompt_messages=prompt_messages
                model_mode=node_data.model.mode, prompt_messages=prompt_messages
            ),
            "usage": None,
            "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
            "tool_call": None,
            "model_provider": model_config.provider,
            "model_name": model_config.model,
            "model_provider": model_instance.provider,
            "model_name": model_instance.model_name,
        }

        try:
            text, usage, tool_call = self._invoke(
                node_data_model=node_data.model,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                tools=prompt_message_tools,
                stop=model_config.stop,
                stop=model_instance.stop,
            )
            process_data["usage"] = jsonable_encoder(usage)
            process_data["tool_call"] = jsonable_encoder(tool_call)
@@ -290,17 +285,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

    def _invoke(
        self,
        node_data_model: ModelConfig,
        model_instance: ModelInstance,
        prompt_messages: list[PromptMessage],
        tools: list[PromptMessageTool],
        stop: list[str],
        stop: Sequence[str],
    ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]:
        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=node_data_model.completion_params,
            model_parameters=dict(model_instance.parameters),
            tools=tools,
            stop=stop,
            stop=list(stop),
            stream=False,
            user=self.user_id,
        )
@@ -324,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -337,7 +331,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        )

        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, "")
        rest_token = self._calculate_rest_token(
            node_data=node_data,
            query=query,
            variable_pool=variable_pool,
            model_instance=model_instance,
            context="",
        )
        prompt_template = self._get_function_calling_prompt_template(
            node_data, query, variable_pool, memory, rest_token
        )
@@ -349,7 +349,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context="",
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config,
            model_instance=model_instance,
            image_detail_config=vision_detail,
        )

@@ -406,7 +406,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -421,7 +421,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                node_data=data,
                query=query,
                variable_pool=variable_pool,
                model_config=model_config,
                model_instance=model_instance,
                memory=memory,
                files=files,
                vision_detail=vision_detail,
@@ -431,7 +431,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                node_data=data,
                query=query,
                variable_pool=variable_pool,
                model_config=model_config,
                model_instance=model_instance,
                memory=memory,
                files=files,
                vision_detail=vision_detail,
@@ -444,7 +444,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -454,7 +454,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(
            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
            node_data=node_data,
            query=query,
            variable_pool=variable_pool,
            model_instance=model_instance,
            context="",
        )
        prompt_template = self._get_prompt_engineering_prompt_template(
            node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token
@@ -467,7 +471,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context="",
            memory_config=node_data.memory,
            memory=memory,
            model_config=model_config,
            model_instance=model_instance,
            image_detail_config=vision_detail,
        )

@@ -478,7 +482,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
@@ -488,7 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        """
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
        rest_token = self._calculate_rest_token(
            node_data=node_data, query=query, variable_pool=variable_pool, model_config=model_config, context=""
            node_data=node_data,
            query=query,
            variable_pool=variable_pool,
            model_instance=model_instance,
            context="",
        )
        prompt_template = self._get_prompt_engineering_prompt_template(
            node_data=node_data,
@@ -508,7 +516,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context="",
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config,
            model_instance=model_instance,
            image_detail_config=vision_detail,
        )

@@ -769,21 +777,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        model_instance: ModelInstance,
        context: str | None,
    ) -> int:
        try:
            model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
        except ValueError as exc:
            raise ModelSchemaNotFoundError("Model schema not found") from exc
        prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)

        model_instance, model_config = self._fetch_model_config(node_data.model)
        if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
            raise InvalidModelTypeError("Model is not a Large Language Model")

        llm_model = model_instance.model_type_instance
        model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
        if not model_schema:
            raise ModelSchemaNotFoundError("Model schema not found")

        if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
        if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}:
            prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
        else:
            prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
@@ -796,27 +799,28 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
            context=context,
            memory_config=node_data.memory,
            memory=None,
            model_config=model_config,
            model_instance=model_instance,
        )
        rest_tokens = 2000

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if model_context_tokens:
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
            curr_message_tokens = (
                model_type_instance.get_num_tokens(model_config.model, model_config.credentials, prompt_messages) + 1000
                model_type_instance.get_num_tokens(
                    model_instance.model_name, model_instance.credentials, prompt_messages
                )
                + 1000
            )  # add 1000 to ensure tool call messages

            max_tokens = 0
            for parameter_rule in model_config.model_schema.parameter_rules:
            for parameter_rule in model_schema.parameter_rules:
                if parameter_rule.name == "max_tokens" or (
                    parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
                ):
                    max_tokens = (
                        model_config.parameters.get(parameter_rule.name)
                        or model_config.parameters.get(parameter_rule.use_template or "")
                        model_instance.parameters.get(parameter_rule.name)
                        or model_instance.parameters.get(parameter_rule.use_template or "")
                    ) or 0

            rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
@@ -824,21 +828,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):

        return rest_tokens

    def _fetch_model_config(
        self, node_data_model: ModelConfig
    ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        """
        Fetch model config.
        """
        if not self._model_instance or not self._model_config:
            self._model_instance, self._model_config = llm_utils.fetch_model_config(
                node_data_model=node_data_model,
                credentials_provider=self._credentials_provider,
                model_factory=self._model_factory,
            )

        return self._model_instance, self._model_config

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
        cls,
@ -3,12 +3,10 @@ import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities import GraphInitParams
|
||||
@ -22,7 +20,12 @@ from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
||||
from core.workflow.nodes.llm import (
|
||||
LLMNode,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
llm_utils,
|
||||
)
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
@ -52,6 +55,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
_llm_file_saver: LLMFileSaver
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
_model_instance: ModelInstance
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -62,6 +66,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
*,
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
model_instance: ModelInstance,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@ -75,6 +80,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
self._model_instance = model_instance
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
@ -95,18 +101,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None
|
||||
query = variable.value if variable else None
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = llm_utils.fetch_model_config(
|
||||
node_data_model=node_data.model,
|
||||
credentials_provider=self._credentials_provider,
|
||||
model_factory=self._model_factory,
|
||||
)
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
model_instance.model_name,
|
||||
model_instance.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {model_instance.model_name}")
|
||||
# fetch model instance
|
||||
model_instance = self._model_instance
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
@ -131,7 +127,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
rest_token = self._calculate_rest_token(
|
||||
node_data=node_data,
|
||||
query=query or "",
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
context="",
|
||||
)
|
||||
prompt_template = self._get_prompt_template(
|
||||
@ -149,9 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
sys_query="",
|
||||
memory=memory,
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=node_data.model.completion_params,
|
stop=model_config.stop,
stop=model_instance.stop,
sys_files=files,
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
@ -166,7 +160,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
@ -205,14 +198,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
category_name = classes_map[category_id_result]
category_id = category_id_result
process_data = {
"model_mode": model_config.mode,
"model_mode": node_data.model.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
model_mode=model_config.mode, prompt_messages=prompt_messages
model_mode=node_data.model.mode, prompt_messages=prompt_messages
),
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"model_provider": model_config.provider,
"model_name": model_config.model,
"model_provider": model_instance.provider,
"model_name": model_instance.model_name,
}
outputs = {
"class_name": category_name,
@ -285,39 +278,40 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
model_config: ModelConfigWithCredentialsEntity,
model_instance: ModelInstance,
context: str | None,
) -> int:
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages = prompt_transform.get_prompt(
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
inputs={},
query="",
files=[],
sys_query="",
sys_files=[],
context=context,
memory_config=node_data.memory,
memory=None,
model_config=model_config,
model_instance=model_instance,
stop=model_instance.stop,
memory_config=node_data.memory,
vision_enabled=False,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
)
rest_tokens = 2000

model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
model_instance = ModelInstance(
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
)

curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

max_tokens = 0
for parameter_rule in model_config.model_schema.parameter_rules:
for parameter_rule in model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
model_config.parameters.get(parameter_rule.name)
or model_config.parameters.get(parameter_rule.use_template or "")
model_instance.parameters.get(parameter_rule.name)
or model_instance.parameters.get(parameter_rule.use_template or "")
) or 0

rest_tokens = model_context_tokens - max_tokens - curr_message_tokens

@ -48,3 +48,19 @@ def get_mocked_fetch_model_config(
)

return MagicMock(return_value=(model_instance, model_config))


def get_mocked_fetch_model_instance(
provider: str,
model: str,
mode: str,
credentials: dict,
):
mock_fetch_model_config = get_mocked_fetch_model_config(
provider=provider,
model=model,
mode=mode,
credentials=credentials,
)
model_instance, _ = mock_fetch_model_config()
return MagicMock(return_value=model_instance)

@ -5,13 +5,13 @@ from collections.abc import Generator
from unittest.mock import MagicMock, patch

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.node_events import StreamCompletedEvent
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
@ -67,21 +67,14 @@ def init_llm_node(config: dict) -> LLMNode:

graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)

graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

node = LLMNode(
id=str(uuid.uuid4()),
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(),
model_factory=MagicMock(),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)

return node
@ -116,8 +109,7 @@ def test_execute_llm():

db.session.close = MagicMock()

# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock

@ -125,7 +117,20 @@ def test_execute_llm():
from core.model_runtime.entities.message_entities import AssistantPromptMessage

# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -149,14 +154,7 @@ def test_execute_llm():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}

return mock_model_instance, mock_model_config
return mock_model_instance

# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
@ -167,10 +165,9 @@ def test_execute_llm():
UserPromptMessage(content="what's the weather today?"),
], []

with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
node._model_instance = build_mock_model_instance()

with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1):
# execute node
result = node._run()
assert isinstance(result, Generator)
@ -228,8 +225,7 @@ def test_execute_llm_with_jinja2():
# Mock db.session.close()
db.session.close = MagicMock()

# Mock the _fetch_model_config method
def mock_fetch_model_config(*_args, **_kwargs):
def build_mock_model_instance() -> MagicMock:
from decimal import Decimal
from unittest.mock import MagicMock

@ -237,7 +233,20 @@ def test_execute_llm_with_jinja2():
from core.model_runtime.entities.message_entities import AssistantPromptMessage

# Create mock model instance
mock_model_instance = MagicMock()
mock_model_instance = MagicMock(spec=ModelInstance)
mock_model_instance.provider = "openai"
mock_model_instance.model_name = "gpt-3.5-turbo"
mock_model_instance.credentials = {}
mock_model_instance.parameters = {}
mock_model_instance.stop = []
mock_model_instance.model_type_instance = MagicMock()
mock_model_instance.model_type_instance.get_model_schema.return_value = MagicMock(
model_properties={},
parameter_rules=[],
features=[],
)
mock_model_instance.provider_model_bundle = MagicMock()
mock_model_instance.provider_model_bundle.configuration.using_provider_type = "custom"
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
@ -261,14 +270,7 @@ def test_execute_llm_with_jinja2():
)
mock_model_instance.invoke_llm.return_value = mock_llm_result

# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}

return mock_model_instance, mock_model_config
return mock_model_instance

# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
@ -279,10 +281,9 @@ def test_execute_llm_with_jinja2():
UserPromptMessage(content="what's the weather today?"),
], []

with (
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
node._model_instance = build_mock_model_instance()

with patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2):
# execute node
result = node._run()


@ -4,18 +4,17 @@ import uuid
from unittest.mock import MagicMock

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from models.enums import UserFrom
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config
from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance

"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
@ -72,14 +71,6 @@ def init_parameter_extractor_node(config: dict):

graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

# Create node factory
node_factory = DifyNodeFactory(
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)

graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

node = ParameterExtractorNode(
id=str(uuid.uuid4()),
config=config,
@ -87,6 +78,7 @@ def init_parameter_extractor_node(config: dict):
graph_runtime_state=graph_runtime_state,
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
)
return node

@ -116,12 +108,12 @@ def test_function_calling_parameter_extractor(setup_model_mock):
}
)

node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()

result = node._run()
@ -157,12 +149,12 @@ def test_instructions(setup_model_mock):
},
)

node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()

result = node._run()
@ -207,12 +199,12 @@ def test_chat_parameter_extractor(setup_model_mock):
},
)

node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()

result = node._run()
@ -258,12 +250,12 @@ def test_completion_parameter_extractor(setup_model_mock):
},
)

node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo-instruct",
mode="completion",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
db.session.close = MagicMock()

result = node._run()
@ -383,12 +375,12 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
},
)

node._fetch_model_config = get_mocked_fetch_model_config(
node._model_instance = get_mocked_fetch_model_instance(
provider="langgenius/openai/openai",
model="gpt-3.5-turbo",
mode="chat",
credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
)
)()
# Test the mock before running the actual test
monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
db.session.close = MagicMock()

@ -1391,10 +1391,20 @@ class TestWorkflowService:

workflow_service = WorkflowService()

from unittest.mock import patch

from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance

# Act
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
):
result = workflow_service.run_free_workflow_node(
node_data=node_data, tenant_id=tenant_id, user_id=user_id, node_id=node_id, user_inputs=user_inputs
)

# Assert
assert result is not None

@ -10,6 +10,7 @@ from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock

from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent
@ -44,9 +45,10 @@ class MockNodeMixin:
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
if isinstance(self, (LLMNode, QuestionClassifierNode)):
if isinstance(self, (LLMNode, QuestionClassifierNode, ParameterExtractorNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))

super().__init__(
id=id,

@ -9,11 +9,12 @@ This test validates that:
"""

import time
from unittest.mock import patch
from unittest.mock import MagicMock, patch
from uuid import uuid4

from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.node_factory import DifyNodeFactory
from core.model_manager import ModelInstance
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
@ -115,7 +116,12 @@ def test_parallel_streaming_workflow():

# Create node factory and graph
node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state)
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)
with patch.object(
DifyNodeFactory,
"_build_model_instance_for_llm_node",
return_value=MagicMock(spec=ModelInstance),
):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory)

# Create the graph engine
engine = GraphEngine(

@ -547,8 +547,22 @@ class TableTestRunner:
"""Run tests in parallel."""
results = []

flask_app: Any = None
try:
from flask import current_app

flask_app = current_app._get_current_object() # type: ignore[attr-defined]
except RuntimeError:
flask_app = None

def _run_test_case_with_context(test_case: WorkflowTestCase) -> WorkflowTestResult:
if flask_app is None:
return self.run_test_case(test_case)
with flask_app.app_context():
return self.run_test_case(test_case)

with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
future_to_test = {executor.submit(self.run_test_case, tc): tc for tc in test_cases}
future_to_test = {executor.submit(_run_test_case_with_context, tc): tc for tc in test_cases}

for future in as_completed(future_to_test):
test_case = future_to_test[future]

@ -9,6 +9,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCre
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.model_manager import ModelInstance
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
@ -115,6 +116,7 @@ def llm_node(
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
)
return node
@ -601,6 +603,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver