diff --git a/api/.importlinter b/api/.importlinter index f615a2ea5f..c9364a0896 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -89,7 +89,6 @@ forbidden_modules = core.logging core.mcp core.memory - core.model_manager core.moderation core.ops core.plugin @@ -117,6 +116,7 @@ ignore_imports = core.workflow.nodes.llm.llm_utils -> configs core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities core.workflow.nodes.llm.llm_utils -> core.model_manager + core.workflow.nodes.llm.protocols -> core.model_manager core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model core.workflow.nodes.llm.llm_utils -> models.model core.workflow.nodes.llm.llm_utils -> models.provider diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index a125050082..80e180ce96 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner): # check if model supports stream tool call llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials) features = model_schema.features if model_schema and model_schema.features else [] self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features self.files = application_generate_entity.files if ModelFeature.VISION in features else [] diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index a55f2d0f5f..0464afe194 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -245,7 +245,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): iteration_step += 1 yield LLMResultChunk( - model=model_instance.model, + model=model_instance.model_name, prompt_messages=prompt_messages, delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] @@ -268,7 +268,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): self.queue_manager.publish( QueueMessageEndEvent( llm_result=LLMResult( - model=model_instance.model, + model=model_instance.model_name, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] or LLMUsage.empty_usage(), diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index f9da2f3b43..633609f54f 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -178,7 +178,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): ) yield LLMResultChunk( - model=model_instance.model, + model=model_instance.model_name, prompt_messages=result.prompt_messages, system_fingerprint=result.system_fingerprint, delta=LLMResultChunkDelta( @@ -308,7 +308,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): self.queue_manager.publish( QueueMessageEndEvent( llm_result=LLMResult( - model=model_instance.model, + model=model_instance.model_name, prompt_messages=prompt_messages, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"] or LLMUsage.empty_usage(), diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 8b6b8f227b..7309113f27 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -178,7 +178,7 @@ class AgentChatAppRunner(AppRunner): # change function call strategy based 
on LLM model llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials) if not model_schema: raise ValueError("Model schema not found") diff --git a/api/core/app/llm/__init__.py b/api/core/app/llm/__init__.py new file mode 100644 index 0000000000..5ac76c8086 --- /dev/null +++ b/api/core/app/llm/__init__.py @@ -0,0 +1 @@ +"""LLM-related application services.""" diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py new file mode 100644 index 0000000000..2b162920ee --- /dev/null +++ b/api/core/app/llm/model_access.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from typing import Any + +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.errors.error import ProviderTokenNotInitError +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.model_entities import ModelType +from core.provider_manager import ProviderManager +from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory + + +class DifyCredentialsProvider: + tenant_id: str + provider_manager: ProviderManager + + def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: + self.tenant_id = tenant_id + self.provider_manager = provider_manager or ProviderManager() + + def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + provider_configurations = self.provider_manager.get_configurations(self.tenant_id) + provider_configuration = provider_configurations.get(provider_name) + if not provider_configuration: + raise ValueError(f"Provider {provider_name} does not exist.") + + provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name) + if provider_model is None: + raise ModelNotExistError(f"Model {model_name} not exist.") + provider_model.raise_for_status() + + credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name) + if credentials is None: + raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") + + return credentials + + +class DifyModelFactory: + tenant_id: str + model_manager: ModelManager + + def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: + self.tenant_id = tenant_id + self.model_manager = model_manager or ModelManager() + + def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: + return self.model_manager.get_model_instance( + tenant_id=self.tenant_id, + provider=provider_name, + model_type=ModelType.LLM, + model=model_name, + ) + + +def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: + return ( + DifyCredentialsProvider(tenant_id=tenant_id), + DifyModelFactory(tenant_id=tenant_id), + ) + + +def fetch_model_config( + *, + node_data_model: ModelConfig, + credentials_provider: CredentialsProvider, + model_factory: ModelFactory, +) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") + + credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name) + model_instance = 
model_factory.init_model_instance(node_data_model.provider, node_data_model.name) + provider_model_bundle = model_instance.provider_model_bundle + + provider_model = provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, + model_type=ModelType.LLM, + ) + if provider_model is None: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() + + stop: list[str] = [] + if "stop" in node_data_model.completion_params: + stop = node_data_model.completion_params.pop("stop") + + model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) + if not model_schema: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + return model_instance, ModelConfigWithCredentialsEntity( + provider=node_data_model.provider, + model=node_data_model.name, + model_schema=model_schema, + mode=node_data_model.mode, + provider_model_bundle=provider_model_bundle, + credentials=credentials, + parameters=node_data_model.completion_params, + stop=stop, + ) diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index 965f3ddb1d..07dec1b070 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, final from typing_extensions import override from configs import dify_config +from core.app.llm.model_access import build_dify_model_access from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.ssrf_proxy import ssrf_proxy @@ -20,8 +21,13 @@ from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.nodes.template_transform.template_renderer import CodeExecutorJinja2TemplateRenderer +from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, +) from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode if TYPE_CHECKING: @@ -95,6 +101,8 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id) + @override def create_node(self, node_config: NodeConfigDict) -> Node: """ @@ -160,6 +168,16 @@ class DifyNodeFactory(NodeFactory): file_manager=self._http_request_file_manager, ) + if node_type == NodeType.LLM: + return LLMNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, + ) + if node_type == NodeType.KNOWLEDGE_RETRIEVAL: return KnowledgeRetrievalNode( id=node_id, @@ -178,6 +196,26 @@ class 
DifyNodeFactory(NodeFactory): unstructured_api_config=self._document_extractor_unstructured_api_config, ) + if node_type == NodeType.QUESTION_CLASSIFIER: + return QuestionClassifierNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, + ) + + if node_type == NodeType.PARAMETER_EXTRACTOR: + return ParameterExtractorNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, + ) + return node_class( id=node_id, config=node_config, diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 5a28bbcc3a..ac096c5e54 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -35,7 +35,7 @@ class ModelInstance: def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): self.provider_model_bundle = provider_model_bundle - self.model = model + self.model_name = model self.provider = provider_model_bundle.configuration.provider.provider self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model) self.model_type_instance = self.provider_model_bundle.model_type_instance @@ -163,7 +163,7 @@ class ModelInstance: Union[LLMResult, Generator], self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, @@ -191,7 +191,7 @@ class ModelInstance: int, self._round_robin_invoke( function=self.model_type_instance.get_num_tokens, - model=self.model, + model=self.model_name, credentials=self.credentials, prompt_messages=prompt_messages, tools=tools, @@ -215,7 +215,7 @@ class ModelInstance: EmbeddingResult, self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, texts=texts, user=user, @@ -243,7 +243,7 @@ class ModelInstance: EmbeddingResult, self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, user=user, @@ -264,7 +264,7 @@ class ModelInstance: list[int], self._round_robin_invoke( function=self.model_type_instance.get_num_tokens, - model=self.model, + model=self.model_name, credentials=self.credentials, texts=texts, ), @@ -294,7 +294,7 @@ class ModelInstance: RerankResult, self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, query=query, docs=docs, @@ -328,7 +328,7 @@ class ModelInstance: RerankResult, self._round_robin_invoke( function=self.model_type_instance.invoke_multimodal_rerank, - model=self.model, + model=self.model_name, credentials=self.credentials, query=query, docs=docs, @@ -352,7 +352,7 @@ class ModelInstance: bool, self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, text=text, user=user, @@ -373,7 +373,7 @@ class ModelInstance: str, self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, file=file, user=user, @@ -396,7 +396,7 @@ class ModelInstance: Iterable[bytes], 
self._round_robin_invoke( function=self.model_type_instance.invoke, - model=self.model, + model=self.model_name, credentials=self.credentials, content_text=content_text, user=user, @@ -469,7 +469,7 @@ class ModelInstance: if not isinstance(self.model_type_instance, TTSModel): raise Exception("Model type instance is not TTSModel") return self.model_type_instance.get_tts_model_voices( - model=self.model, credentials=self.credentials, language=language + model=self.model_name, credentials=self.credentials, language=language ) diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index a96b094e6d..2b32062140 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -47,7 +47,9 @@ class AgentHistoryPromptTransform(PromptTransform): model_type_instance = cast(LargeLanguageModel, model_type_instance) curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages + self.model_config.model, + self.model_config.credentials, + self.history_messages, ) if curr_message_tokens <= max_token_limit: return self.history_messages @@ -63,7 +65,9 @@ class AgentHistoryPromptTransform(PromptTransform): # a message is start with UserPromptMessage if isinstance(prompt_message, UserPromptMessage): curr_message_tokens = model_type_instance.get_num_tokens( - self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages + self.model_config.model, + self.model_config.credentials, + prompt_messages, ) # if current message token is overflow, drop all the prompts in current message and break if curr_message_tokens > max_token_limit: diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 3cbc7db75d..0efe19a57c 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -35,7 +35,9 @@ class CacheEmbedding(Embeddings): embedding = ( db.session.query(Embedding) .filter_by( - model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider + model_name=self._model_instance.model_name, + hash=hash, + provider_name=self._model_instance.provider, ) .first() ) @@ -52,7 +54,7 @@ class CacheEmbedding(Embeddings): try: model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance) model_schema = model_type_instance.get_model_schema( - self._model_instance.model, self._model_instance.credentials + self._model_instance.model_name, self._model_instance.credentials ) max_chunks = ( model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] @@ -87,7 +89,7 @@ class CacheEmbedding(Embeddings): hash = helper.generate_text_hash(texts[i]) if hash not in cache_embeddings: embedding_cache = Embedding( - model_name=self._model_instance.model, + model_name=self._model_instance.model_name, hash=hash, provider_name=self._model_instance.provider, embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL), @@ -114,7 +116,9 @@ class CacheEmbedding(Embeddings): embedding = ( db.session.query(Embedding) .filter_by( - model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider + model_name=self._model_instance.model_name, + hash=file_id, + provider_name=self._model_instance.provider, ) .first() ) @@ -131,7 +135,7 @@ class CacheEmbedding(Embeddings): try: model_type_instance = cast(TextEmbeddingModel, 
self._model_instance.model_type_instance) model_schema = model_type_instance.get_model_schema( - self._model_instance.model, self._model_instance.credentials + self._model_instance.model_name, self._model_instance.credentials ) max_chunks = ( model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] @@ -168,7 +172,7 @@ class CacheEmbedding(Embeddings): file_id = multimodel_documents[i]["file_id"] if file_id not in cache_embeddings: embedding_cache = Embedding( - model_name=self._model_instance.model, + model_name=self._model_instance.model_name, hash=file_id, provider_name=self._model_instance.provider, embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL), @@ -190,7 +194,7 @@ class CacheEmbedding(Embeddings): """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) - embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}" + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}" embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) @@ -233,7 +237,7 @@ class CacheEmbedding(Embeddings): """Embed multimodal documents.""" # use doc embedding cache or store if not exists file_id = multimodel_document["file_id"] - embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}" + embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}" embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 38309d3d77..690e780921 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner): is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, - model=self.rerank_model_instance.model, + model=self.rerank_model_instance.model_name, model_type=ModelType.RERANK, ) if not is_support_vision: diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index b4bae08a9b..e7fba09359 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -47,7 +47,7 @@ class ModelInvocationUtils: raise InvokeModelError("Model not found") llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials) if not schema: raise InvokeModelError("No model schema found") diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py index 78fad37659..341a1c1a4c 100644 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -8,7 +8,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities 
import LLMUsage from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -17,6 +17,8 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme from core.workflow.enums import SystemVariableKey from core.workflow.file.models import File from core.workflow.nodes.llm.entities import ModelConfig +from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.runtime import VariablePool from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -24,49 +26,46 @@ from models.model import Conversation from models.provider import Provider, ProviderType from models.provider_ids import ModelProviderID -from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError +from .exc import InvalidVariableTypeError def fetch_model_config( - tenant_id: str, node_data_model: ModelConfig + *, + node_data_model: ModelConfig, + credentials_provider: CredentialsProvider, + model_factory: ModelFactory, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") - model = ModelManager().get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=node_data_model.provider, + credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name) + model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name) + provider_model_bundle = model_instance.provider_model_bundle + + provider_model = provider_model_bundle.configuration.get_provider_model( model=node_data_model.name, + model_type=ModelType.LLM, ) - - model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) - - # check model - provider_model = model.provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, model_type=ModelType.LLM - ) - if provider_model is None: raise ModelNotExistError(f"Model {node_data_model.name} not exist.") provider_model.raise_for_status() - # model config stop: list[str] = [] if "stop" in node_data_model.completion_params: stop = node_data_model.completion_params.pop("stop") - model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) + model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) if not model_schema: raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - return model, ModelConfigWithCredentialsEntity( + model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) + return model_instance, ModelConfigWithCredentialsEntity( provider=node_data_model.provider, model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=model.provider_model_bundle, - credentials=model.credentials, + provider_model_bundle=provider_model_bundle, + credentials=credentials, parameters=node_data_model.completion_params, stop=stop, ) @@ -131,7 +130,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs if quota_unit == QuotaUnit.TOKENS: used_quota = usage.total_tokens elif quota_unit == QuotaUnit.CREDITS: - used_quota = dify_config.get_model_credits(model_instance.model) + used_quota = dify_config.get_model_credits(model_instance.model_name) 
else: used_quota = 1 diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 49ae5d16c7..0259434d90 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -16,7 +16,7 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, @@ -38,11 +38,7 @@ from core.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ( - ModelFeature, - ModelPropertyKey, - ModelType, -) +from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -76,6 +72,7 @@ from core.workflow.node_events import ( from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.runtime import VariablePool from extensions.ext_database import db from models.dataset import SegmentAttachmentBinding @@ -93,7 +90,6 @@ from .exc import ( InvalidVariableTypeError, LLMNodeError, MemoryRolePrefixRequiredError, - ModelNotExistError, NoPromptFoundError, TemplateTypeNotSupportError, VariableNotFoundError, @@ -118,6 +114,8 @@ class LLMNode(Node[LLMNodeData]): _file_outputs: list[File] _llm_file_saver: LLMFileSaver + _credentials_provider: CredentialsProvider + _model_factory: ModelFactory def __init__( self, @@ -126,6 +124,8 @@ class LLMNode(Node[LLMNodeData]): graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, + credentials_provider: CredentialsProvider, + model_factory: ModelFactory, llm_file_saver: LLMFileSaver | None = None, ): super().__init__( @@ -137,6 +137,9 @@ class LLMNode(Node[LLMNodeData]): # LLM file outputs, used for MultiModal outputs. 
self._file_outputs = [] + self._credentials_provider = credentials_provider + self._model_factory = model_factory + if llm_file_saver is None: llm_file_saver = FileSaverImpl( user_id=graph_init_params.user_id, @@ -199,10 +202,21 @@ class LLMNode(Node[LLMNodeData]): node_inputs["#context_files#"] = [file.model_dump() for file in context_files] # fetch model config - model_instance, model_config = LLMNode._fetch_model_config( + model_instance, model_config = self._fetch_model_config( node_data_model=self.node_data.model, - tenant_id=self.tenant_id, ) + model_name = getattr(model_instance, "model_name", None) + if not isinstance(model_name, str): + model_name = model_config.model + model_provider = getattr(model_instance, "provider", None) + if not isinstance(model_provider, str): + model_provider = model_config.provider + model_schema = model_instance.model_type_instance.get_model_schema( + model_name, + model_instance.credentials, + ) + if not model_schema: + raise ValueError(f"Model schema not found for {model_name}") # fetch memory memory = llm_utils.fetch_memory( @@ -225,14 +239,16 @@ class LLMNode(Node[LLMNodeData]): sys_files=files, context=context, memory=memory, - model_config=model_config, + model_instance=model_instance, + model_schema=model_schema, + model_parameters=self.node_data.model.completion_params, + stop=model_config.stop, prompt_template=self.node_data.prompt_template, memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, - tenant_id=self.tenant_id, context_files=context_files, ) @@ -286,14 +302,14 @@ class LLMNode(Node[LLMNodeData]): structured_output = event process_data = { - "model_mode": model_config.mode, + "model_mode": self.node_data.model.mode, "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=model_config.mode, prompt_messages=prompt_messages + model_mode=self.node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), "finish_reason": finish_reason, - "model_provider": model_config.provider, - "model_name": model_config.model, + "model_provider": model_provider, + "model_name": model_name, } outputs = { @@ -755,21 +771,18 @@ class LLMNode(Node[LLMNodeData]): return None - @staticmethod def _fetch_model_config( + self, *, node_data_model: ModelConfig, - tenant_id: str, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model, model_config_with_cred = llm_utils.fetch_model_config( - tenant_id=tenant_id, node_data_model=node_data_model + node_data_model=node_data_model, + credentials_provider=self._credentials_provider, + model_factory=self._model_factory, ) completion_params = model_config_with_cred.parameters - model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - model_config_with_cred.parameters = completion_params # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`. 
node_data_model.completion_params = completion_params @@ -782,14 +795,16 @@ class LLMNode(Node[LLMNodeData]): sys_files: Sequence[File], context: str | None = None, memory: TokenBufferMemory | None = None, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + model_schema: AIModelEntity, + model_parameters: Mapping[str, Any], prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], - tenant_id: str, context_files: list[File] | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] @@ -810,7 +825,9 @@ class LLMNode(Node[LLMNodeData]): memory_messages = _handle_memory_chat_mode( memory=memory, memory_config=memory_config, - model_config=model_config, + model_instance=model_instance, + model_schema=model_schema, + model_parameters=model_parameters, ) # Extend prompt_messages with memory messages prompt_messages.extend(memory_messages) @@ -847,7 +864,9 @@ class LLMNode(Node[LLMNodeData]): memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, - model_config=model_config, + model_instance=model_instance, + model_schema=model_schema, + model_parameters=model_parameters, ) # Insert histories into the prompt prompt_content = prompt_messages[0].content @@ -924,7 +943,7 @@ class LLMNode(Node[LLMNodeData]): prompt_message_content: list[PromptMessageContentUnionTypes] = [] for content_item in prompt_message.content: # Skip content if features are not defined - if not model_config.model_schema.features: + if not model_schema.features: if content_item.type != PromptMessageContentType.TEXT: continue prompt_message_content.append(content_item) @@ -934,19 +953,19 @@ class LLMNode(Node[LLMNodeData]): if ( ( content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_config.model_schema.features + and ModelFeature.VISION not in model_schema.features ) or ( content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_config.model_schema.features + and ModelFeature.DOCUMENT not in model_schema.features ) or ( content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_config.model_schema.features + and ModelFeature.VIDEO not in model_schema.features ) or ( content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_config.model_schema.features + and ModelFeature.AUDIO not in model_schema.features ) ): continue @@ -965,19 +984,7 @@ class LLMNode(Node[LLMNodeData]): "Please ensure a prompt is properly configured before proceeding." 
) - model = ModelManager().get_model_instance( - tenant_id=tenant_id, - model_type=ModelType.LLM, - provider=model_config.provider, - model=model_config.model, - ) - model_schema = model.model_type_instance.get_model_schema( - model=model_config.model, - credentials=model.credentials, - ) - if not model_schema: - raise ModelNotExistError(f"Model {model_config.model} not exist.") - return filtered_prompt_messages, model_config.stop + return filtered_prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping( @@ -1306,26 +1313,26 @@ def _render_jinja2_message( def _calculate_rest_token( - *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity + *, + prompt_messages: list[PromptMessage], + model_instance: ModelInstance, + model_schema: AIModelEntity, + model_parameters: Mapping[str, Any], ) -> int: rest_tokens = 2000 - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: + for parameter_rule in model_schema.parameter_rules: if parameter_rule.name == "max_tokens" or ( parameter_rule.use_template and parameter_rule.use_template == "max_tokens" ): max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(str(parameter_rule.use_template)) + model_parameters.get(parameter_rule.name) + or model_parameters.get(str(parameter_rule.use_template)) or 0 ) @@ -1339,12 +1346,19 @@ def _handle_memory_chat_mode( *, memory: TokenBufferMemory | None, memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + model_schema: AIModelEntity, + model_parameters: Mapping[str, Any], ) -> Sequence[PromptMessage]: memory_messages: Sequence[PromptMessage] = [] # Get messages from memory for chat model if memory and memory_config: - rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + model_schema=model_schema, + model_parameters=model_parameters, + ) memory_messages = memory.get_history_prompt_messages( max_token_limit=rest_tokens, message_limit=memory_config.window.size if memory_config.window.enabled else None, @@ -1356,12 +1370,19 @@ def _handle_memory_completion_mode( *, memory: TokenBufferMemory | None, memory_config: MemoryConfig | None, - model_config: ModelConfigWithCredentialsEntity, + model_instance: ModelInstance, + model_schema: AIModelEntity, + model_parameters: Mapping[str, Any], ) -> str: memory_text = "" # Get history text from memory for completion model if memory and memory_config: - rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + model_schema=model_schema, + model_parameters=model_parameters, + ) if not memory_config.role_prefix: raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") memory_text = memory.get_history_prompt_text( diff --git a/api/core/workflow/nodes/llm/protocols.py 
b/api/core/workflow/nodes/llm/protocols.py new file mode 100644 index 0000000000..8e0365299d --- /dev/null +++ b/api/core/workflow/nodes/llm/protocols.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from core.model_manager import ModelInstance + + +class CredentialsProvider(Protocol): + """Port for loading runtime credentials for a provider/model pair.""" + + def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: + """Return credentials for the target provider/model or raise a domain error.""" + ... + + +class ModelFactory(Protocol): + """Port for creating initialized LLM model instances for execution.""" + + def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: + """Create a model instance that is ready for schema lookup and invocation.""" + ... diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 2f11a91b7e..f549d44efa 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -3,7 +3,7 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -60,6 +60,11 @@ from .prompts import ( logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from core.workflow.entities import GraphInitParams + from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory + from core.workflow.runtime import GraphRuntimeState + def extract_json(text): """ @@ -92,6 +97,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): _model_instance: ModelInstance | None = None _model_config: ModelConfigWithCredentialsEntity | None = None + _credentials_provider: "CredentialsProvider" + _model_factory: "ModelFactory" + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + credentials_provider: "CredentialsProvider", + model_factory: "ModelFactory", + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._credentials_provider = credentials_provider + self._model_factory = model_factory @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -806,7 +832,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ if not self._model_instance or not self._model_config: self._model_instance, self._model_config = llm_utils.fetch_model_config( - tenant_id=self.tenant_id, node_data_model=node_data_model + node_data_model=node_data_model, + credentials_provider=self._credentials_provider, + model_factory=self._model_factory, ) return self._model_instance, self._model_config diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 6491e8e531..3f41c0d0b7 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -24,6 +24,7 @@ from core.workflow.nodes.base.node import 
Node from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from libs.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData @@ -49,6 +50,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _file_outputs: list["File"] _llm_file_saver: LLMFileSaver + _credentials_provider: "CredentialsProvider" + _model_factory: "ModelFactory" def __init__( self, @@ -57,6 +60,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, + credentials_provider: "CredentialsProvider", + model_factory: "ModelFactory", llm_file_saver: LLMFileSaver | None = None, ): super().__init__( @@ -68,6 +73,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] + self._credentials_provider = credentials_provider + self._model_factory = model_factory + if llm_file_saver is None: llm_file_saver = FileSaverImpl( user_id=graph_init_params.user_id, @@ -89,9 +97,16 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): variables = {"query": query} # fetch model config model_instance, model_config = llm_utils.fetch_model_config( - tenant_id=self.tenant_id, node_data_model=node_data.model, + credentials_provider=self._credentials_provider, + model_factory=self._model_factory, ) + model_schema = model_instance.model_type_instance.get_model_schema( + model_instance.model_name, + model_instance.credentials, + ) + if not model_schema: + raise ValueError(f"Model schema not found for {model_instance.model_name}") # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -133,13 +148,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): prompt_template=prompt_template, sys_query="", memory=memory, - model_config=model_config, + model_instance=model_instance, + model_schema=model_schema, + model_parameters=node_data.model.completion_params, + stop=model_config.stop, sys_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], - tenant_id=self.tenant_id, ) result_text = "" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 29ffb8027f..a724fbcab7 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,8 +1,7 @@ import logging import time -import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any +from typing import Any, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError @@ -11,6 +10,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer from core.app.workflow.node_factory import DifyNodeFactory from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID from core.workflow.entities import GraphInitParams +from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.file.models import File from core.workflow.graph import Graph @@ -168,7 +168,8 @@ class WorkflowEntry: 
graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - node = node_factory.create_node(node_config) + typed_node_config = cast(dict[str, object], node_config) + node = cast(Any, node_factory).create_node(typed_node_config) node_cls = type(node) try: @@ -256,7 +257,7 @@ class WorkflowEntry: @classmethod def run_free_node( - cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] + cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any] ) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]: """ Run free node @@ -302,16 +303,15 @@ class WorkflowEntry: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node_config = { + node_config: NodeConfigDict = { "id": node_id, - "data": node_data, + "data": cast(NodeConfigData, node_data), } - node: Node = node_cls( - id=str(uuid.uuid4()), - config=node_config, + node_factory = DifyNodeFactory( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + node = node_factory.create_node(node_config) try: # variable selector to variable mapping diff --git a/api/services/app_service.py b/api/services/app_service.py index af458ff618..e57253f8b6 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -107,19 +107,19 @@ class AppService: if model_instance: if ( - model_instance.model == default_model_config["model"]["name"] + model_instance.model_name == default_model_config["model"]["name"] and model_instance.provider == default_model_config["model"]["provider"] ): default_model_dict = default_model_config["model"] else: llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) - model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials) if model_schema is None: - raise ValueError(f"model schema not found for model {model_instance.model}") + raise ValueError(f"model schema not found for model {model_instance.model_name}") default_model_dict = { "provider": model_instance.provider, - "name": model_instance.model, + "name": model_instance.model_name, "mode": model_schema.model_properties.get(ModelPropertyKey.MODE), "completion_params": {}, } diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 785e02a19a..35b20f7601 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -252,7 +252,7 @@ class DatasetService: dataset.updated_by = account.id dataset.tenant_id = tenant_id dataset.embedding_model_provider = embedding_model.provider if embedding_model else None - dataset.embedding_model = embedding_model.model if embedding_model else None + dataset.embedding_model = embedding_model.model_name if embedding_model else None dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None dataset.permission = permission or DatasetPermissionEnum.ONLY_ME dataset.provider = provider @@ -384,7 +384,7 @@ class DatasetService: model=model, ) text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance) - model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials) + model_schema = text_embedding_model.get_model_schema(model_instance.model_name, model_instance.credentials) if not model_schema: raise ValueError("Model schema not found") if 
model_schema.features and ModelFeature.VISION in model_schema.features: @@ -743,10 +743,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=data["embedding_model"], ) - filtered_data["embedding_model"] = embedding_model.model + embedding_model_name = embedding_model.model_name + filtered_data["embedding_model"] = embedding_model_name filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + embedding_model.provider, + embedding_model_name, ) filtered_data["collection_binding_id"] = dataset_collection_binding.id except LLMBadRequestError: @@ -876,10 +878,12 @@ class DatasetService: return # Apply new embedding model settings - filtered_data["embedding_model"] = embedding_model.model + embedding_model_name = embedding_model.model_name + filtered_data["embedding_model"] = embedding_model_name filtered_data["embedding_model_provider"] = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + embedding_model.provider, + embedding_model_name, ) filtered_data["collection_binding_id"] = dataset_collection_binding.id @@ -955,10 +959,12 @@ class DatasetService: knowledge_configuration.embedding_model, ) dataset.is_multimodal = is_multimodal - dataset.embedding_model = embedding_model.model + embedding_model_name = embedding_model.model_name + dataset.embedding_model = embedding_model_name dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + embedding_model.provider, + embedding_model_name, ) dataset.collection_binding_id = dataset_collection_binding.id elif knowledge_configuration.indexing_technique == "economy": @@ -989,10 +995,12 @@ class DatasetService: model_type=ModelType.TEXT_EMBEDDING, model=knowledge_configuration.embedding_model, ) - dataset.embedding_model = embedding_model.model + embedding_model_name = embedding_model.model_name + dataset.embedding_model = embedding_model_name dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + embedding_model.provider, + embedding_model_name, ) is_multimodal = DatasetService.check_is_multimodal_model( current_user.current_tenant_id, @@ -1049,11 +1057,13 @@ class DatasetService: skip_embedding_update = True if not skip_embedding_update: if embedding_model: - dataset.embedding_model = embedding_model.model + embedding_model_name = embedding_model.model_name + dataset.embedding_model = embedding_model_name dataset.embedding_model_provider = embedding_model.provider dataset_collection_binding = ( DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model + embedding_model.provider, + embedding_model_name, ) ) dataset.collection_binding_id = dataset_collection_binding.id @@ -1884,7 +1894,7 @@ class DocumentService: embedding_model = model_manager.get_default_model_instance( tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) - dataset_embedding_model = embedding_model.model + dataset_embedding_model = embedding_model.model_name dataset_embedding_model_provider = embedding_model.provider 
dataset.embedding_model = dataset_embedding_model dataset.embedding_model_provider = dataset_embedding_model_provider diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index c361bfcc6f..1b341e8f21 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -80,6 +80,8 @@ def init_llm_node(config: dict) -> LLMNode: config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + credentials_provider=MagicMock(), + model_factory=MagicMock(), ) return node @@ -115,7 +117,7 @@ def test_execute_llm(): db.session.close = MagicMock() # Mock the _fetch_model_config to avoid database calls - def mock_fetch_model_config(**_kwargs): + def mock_fetch_model_config(*_args, **_kwargs): from decimal import Decimal from unittest.mock import MagicMock @@ -227,7 +229,7 @@ def test_execute_llm_with_jinja2(): db.session.close = MagicMock() # Mock the _fetch_model_config method - def mock_fetch_model_config(**_kwargs): + def mock_fetch_model_config(*_args, **_kwargs): from decimal import Decimal from unittest.mock import MagicMock diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 7445699a86..88edc4f9b3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -9,6 +9,7 @@ from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus from core.workflow.graph import Graph +from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from core.workflow.runtime import GraphRuntimeState, VariablePool from core.workflow.system_variable import SystemVariable @@ -84,6 +85,8 @@ def init_parameter_extractor_node(config: dict): config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + credentials_provider=MagicMock(spec=CredentialsProvider), + model_factory=MagicMock(spec=ModelFactory), ) return node diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 608fc76bd2..f6d9dfddae 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -331,7 +331,7 @@ class TestDatasetServiceUpdateDataset: ) embedding_model = Mock() - embedding_model.model = "text-embedding-ada-002" + embedding_model.model_name = "text-embedding-ada-002" embedding_model.provider = "openai" binding = Mock() @@ -424,7 +424,7 @@ class TestDatasetServiceUpdateDataset: ) embedding_model = Mock() - embedding_model.model = "text-embedding-3-small" + embedding_model.model_name = "text-embedding-3-small" embedding_model.provider = "openai" binding = Mock() diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 025a0d8d70..63596bc320 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ 
b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -82,7 +82,7 @@ class TestCacheEmbeddingDocuments: Mock: Configured ModelInstance with text embedding capabilities """ model_instance = Mock() - model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -597,7 +597,7 @@ class TestCacheEmbeddingQuery: def mock_model_instance(self): """Create a mock ModelInstance for testing.""" model_instance = Mock() - model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -830,7 +830,7 @@ class TestEmbeddingModelSwitching: """ # Arrange model_instance_ada = Mock() - model_instance_ada.model = "text-embedding-ada-002" + model_instance_ada.model_name = "text-embedding-ada-002" model_instance_ada.provider = "openai" # Mock model type instance for ada @@ -841,7 +841,7 @@ class TestEmbeddingModelSwitching: model_type_instance_ada.get_model_schema.return_value = model_schema_ada model_instance_3_small = Mock() - model_instance_3_small.model = "text-embedding-3-small" + model_instance_3_small.model_name = "text-embedding-3-small" model_instance_3_small.provider = "openai" # Mock model type instance for 3-small @@ -914,11 +914,11 @@ class TestEmbeddingModelSwitching: """ # Arrange model_instance_openai = Mock() - model_instance_openai.model = "text-embedding-ada-002" + model_instance_openai.model_name = "text-embedding-ada-002" model_instance_openai.provider = "openai" model_instance_cohere = Mock() - model_instance_cohere.model = "embed-english-v3.0" + model_instance_cohere.model_name = "embed-english-v3.0" model_instance_cohere.provider = "cohere" cache_openai = CacheEmbedding(model_instance_openai) @@ -1001,7 +1001,7 @@ class TestEmbeddingDimensionValidation: def mock_model_instance(self): """Create a mock ModelInstance for testing.""" model_instance = Mock() - model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -1123,7 +1123,7 @@ class TestEmbeddingDimensionValidation: """ # Arrange - OpenAI ada-002 (1536 dimensions) model_instance_ada = Mock() - model_instance_ada.model = "text-embedding-ada-002" + model_instance_ada.model_name = "text-embedding-ada-002" model_instance_ada.provider = "openai" # Mock model type instance for ada @@ -1156,7 +1156,7 @@ class TestEmbeddingDimensionValidation: # Arrange - Cohere embed-english-v3.0 (1024 dimensions) model_instance_cohere = Mock() - model_instance_cohere.model = "embed-english-v3.0" + model_instance_cohere.model_name = "embed-english-v3.0" model_instance_cohere.provider = "cohere" # Mock model type instance for cohere @@ -1225,7 +1225,7 @@ class TestEmbeddingEdgeCases: - MAX_CHUNKS: 10 """ model_instance = Mock() - model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_type_instance = Mock() @@ -1702,7 +1702,7 @@ class TestEmbeddingCachePerformance: - MAX_CHUNKS: 10 """ model_instance = Mock() - model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_type_instance = Mock() diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py 
diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py
index ebe6c37818..3cecc92c16 100644
--- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py
+++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py
@@ -34,7 +34,7 @@ def create_mock_model_instance():
     mock_instance.provider_model_bundle.configuration = Mock()
     mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
     mock_instance.provider = "test-provider"
-    mock_instance.model = "test-model"
+    mock_instance.model_name = "test-model"
     return mock_instance
@@ -65,7 +65,7 @@ class TestRerankModelRunner:
         mock_instance.provider_model_bundle.configuration = Mock()
         mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
         mock_instance.provider = "test-provider"
-        mock_instance.model = "test-model"
+        mock_instance.model_name = "test-model"
         return mock_instance

     @pytest.fixture
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
index 1c6d057863..b291f95e0f 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py
@@ -199,11 +199,32 @@ def test_mock_config_builder():

 def test_mock_factory_node_type_detection():
     """Test that MockNodeFactory correctly identifies nodes to mock."""
+    from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom
+
     from .test_mock_factory import MockNodeFactory

+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
     factory = MockNodeFactory(
-        graph_init_params=None,  # Will be set by test
-        graph_runtime_state=None,  # Will be set by test
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
         mock_config=None,
     )
@@ -288,7 +309,11 @@ def test_workflow_without_auto_mock():

 def test_register_custom_mock_node():
     """Test registering a custom mock implementation for a node type."""
+    from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
     from core.workflow.nodes.template_transform import TemplateTransformNode
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom

     from .test_mock_factory import MockNodeFactory
@@ -298,9 +323,25 @@ def test_register_custom_mock_node():
         # Custom mock implementation
         pass

+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
     factory = MockNodeFactory(
-        graph_init_params=None,
-        graph_runtime_state=None,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
         mock_config=None,
     )
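The same `GraphInitParams`/`GraphRuntimeState` stub construction now repeats in several test modules below. A shared helper could consolidate it; the constructor arguments follow the diff exactly, while the helper name itself is hypothetical:

from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom


def make_stub_graph_context(graph_config=None) -> tuple[GraphInitParams, GraphRuntimeState]:
    """Build the minimal init-params/runtime-state pair the factory tests need."""
    init_params = GraphInitParams(
        tenant_id="test",
        app_id="test",
        workflow_id="test",
        graph_config=graph_config or {},
        user_id="test",
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.SERVICE_API,
        call_depth=0,
    )
    runtime_state = GraphRuntimeState(
        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
        start_at=0,
        total_tokens=0,
        node_run_steps=0,
    )
    return init_params, runtime_state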
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py
index 194d009288..b117b26b4c 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py
@@ -1,9 +1,9 @@
 import datetime
 import time
 from collections.abc import Iterable
+from unittest import mock
 from unittest.mock import MagicMock

-from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.workflow.entities import GraphInitParams
 from core.workflow.graph import Graph
@@ -82,7 +82,7 @@ def _build_branching_graph(
     def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
         llm_data = LLMNodeData(
             title=title,
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
+            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
             prompt_template=[
                 LLMNodeChatModelMessage(
                     text=prompt_text,
@@ -101,6 +101,8 @@ def _build_branching_graph(
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
             mock_config=mock_config,
+            credentials_provider=mock.Mock(),
+            model_factory=mock.Mock(),
         )
         return llm_node
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py
index d8f229205b..45505909ea 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py
@@ -1,8 +1,8 @@
 import datetime
 import time
+from unittest import mock
 from unittest.mock import MagicMock

-from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
 from core.workflow.entities import GraphInitParams
 from core.workflow.graph import Graph
@@ -78,7 +78,7 @@ def _build_llm_human_llm_graph(
     def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
         llm_data = LLMNodeData(
             title=title,
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
+            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
             prompt_template=[
                 LLMNodeChatModelMessage(
                     text=prompt_text,
@@ -97,6 +97,8 @@ def _build_llm_human_llm_graph(
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
             mock_config=mock_config,
+            credentials_provider=mock.Mock(),
+            model_factory=mock.Mock(),
        )
         return llm_node
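The two pause tests above replace `mode=LLMMode.CHAT` with the plain string `"chat"`, and drop the `LLMMode` import entirely. This works if `ModelConfig.mode` is a pydantic field typed as a str-based enum, since pydantic coerces the raw string to the enum member during validation. A hedged standalone illustration with stand-in classes, not Dify's real `ModelConfig` or `LLMMode`:

from enum import Enum

from pydantic import BaseModel


class LLMModeSketch(str, Enum):  # stand-in for core.model_runtime.entities.llm_entities.LLMMode
    CHAT = "chat"
    COMPLETION = "completion"


class ModelConfigSketch(BaseModel):  # stand-in for core.workflow.nodes.llm.entities.ModelConfig
    provider: str
    name: str
    mode: LLMModeSketch
    completion_params: dict = {}


a = ModelConfigSketch(provider="openai", name="gpt-3.5-turbo", mode="chat")
b = ModelConfigSketch(provider="openai", name="gpt-3.5-turbo", mode=LLMModeSketch.CHAT)
# Both spellings validate to the same enum member.
assert a.mode is b.mode is LLMModeSketch.CHAT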
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py
index 9fa6ee57eb..f33d37e8ff 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py
@@ -1,4 +1,5 @@
 import time
+from unittest import mock

 from core.model_runtime.entities.llm_entities import LLMMode
 from core.model_runtime.entities.message_entities import PromptMessageRole
@@ -85,6 +86,8 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
             mock_config=mock_config,
+            credentials_provider=mock.Mock(),
+            model_factory=mock.Mock(),
         )
         return llm_node
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
index 8c58fe1922..186f8a8425 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py
@@ -5,6 +5,7 @@ This module provides a MockNodeFactory that automatically detects and mocks node
 requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
 """

+from collections.abc import Mapping
 from typing import TYPE_CHECKING, Any

 from core.app.workflow.node_factory import DifyNodeFactory
@@ -74,7 +75,7 @@ class MockNodeFactory(DifyNodeFactory):
         NodeType.CODE: MockCodeNode,
     }

-    def create_node(self, node_config: dict[str, Any]) -> Node:
+    def create_node(self, node_config: Mapping[str, Any]) -> Node:
         """
         Create a node instance, using mock implementations for third-party service nodes.
@@ -123,6 +124,16 @@ class MockNodeFactory(DifyNodeFactory):
                 mock_config=self.mock_config,
                 http_request_config=self._http_request_config,
             )
+        elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
+            mock_instance = mock_class(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                mock_config=self.mock_config,
+                credentials_provider=self._llm_credentials_provider,
+                model_factory=self._llm_model_factory,
+            )
         else:
             mock_instance = mock_class(
                 id=node_id,
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py
index 1cda6ced31..aae4de9a27 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py
@@ -16,9 +16,33 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo

 def test_mock_factory_registers_iteration_node():
     """Test that MockNodeFactory has iteration node registered."""
+    from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom

     # Create a MockNodeFactory instance
-    factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={"nodes": [], "edges": []},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
+    factory = MockNodeFactory(
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+        mock_config=None,
+    )

     # Check that iteration node is registered
     assert NodeType.ITERATION in factory._mock_node_types
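One detail worth noting in the test_mock_factory.py change above: `create_node` now takes `Mapping[str, Any]` instead of `dict[str, Any]`. A `Mapping` parameter promises read-only access, so callers can pass plain dicts, `MappingProxyType` views, or other mapping types without copying. A small standalone sketch (the function body and keys are illustrative):

from collections.abc import Mapping
from types import MappingProxyType
from typing import Any


def create_node(node_config: Mapping[str, Any]) -> str:
    # Read-only access only; mutating node_config here would be flagged
    # by a type checker, since Mapping has no __setitem__.
    return str(node_config["id"])


frozen_config = MappingProxyType({"id": "node-1", "data": {"type": "llm"}})
assert create_node(frozen_config) == "node-1"  # a plain dict works equally well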
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
index 2179ff663b..71e8a9d863 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py
@@ -8,6 +8,7 @@ allowing tests to run without external dependencies.
 import time
 from collections.abc import Generator, Mapping
 from typing import TYPE_CHECKING, Any, Optional
+from unittest.mock import MagicMock

 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -18,6 +19,7 @@ from core.workflow.nodes.document_extractor import DocumentExtractorNode
 from core.workflow.nodes.http_request import HttpRequestNode
 from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
 from core.workflow.nodes.llm import LLMNode
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
 from core.workflow.nodes.question_classifier import QuestionClassifierNode
 from core.workflow.nodes.template_transform import TemplateTransformNode
@@ -42,6 +44,10 @@ class MockNodeMixin:
         mock_config: Optional["MockConfig"] = None,
         **kwargs: Any,
     ):
+        if isinstance(self, (LLMNode, QuestionClassifierNode)):
+            kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
+            kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
+
         super().__init__(
             id=id,
             config=config,
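The spec'd mocks above imply the shape of the two ports. A plausible reconstruction of the `CredentialsProvider` and `ModelFactory` protocols in `core.workflow.nodes.llm.protocols`, inferred from the call sites and assertions in this diff (the real definitions may differ in details such as the concrete return type):

from typing import Any, Protocol


class CredentialsProvider(Protocol):
    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        """Return validated credentials for the given provider/model pair."""
        ...


class ModelFactory(Protocol):
    def init_model_instance(self, provider_name: str, model_name: str) -> Any:
        """Return a ready-to-invoke model instance (ModelInstance in Dify)."""
        ...

Spec'ing a `MagicMock` against a Protocol like this restricts the mock to `fetch` or `init_model_instance`, so a typo in a test call raises immediately instead of recording a bogus call.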
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
index eaf1317937..1b781545f5 100644
--- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
+++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py
@@ -101,11 +101,32 @@ def test_node_mock_config():

 def test_mock_factory_detection():
     """Test MockNodeFactory node type detection."""
+    from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom
+
     print("Testing MockNodeFactory detection...")

+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
     factory = MockNodeFactory(
-        graph_init_params=None,
-        graph_runtime_state=None,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
         mock_config=None,
     )
@@ -133,11 +154,32 @@ def test_mock_factory_detection():

 def test_mock_factory_registration():
     """Test registering and unregistering mock node types."""
+    from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom
+
     print("Testing MockNodeFactory registration...")

+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
     factory = MockNodeFactory(
-        graph_init_params=None,
-        graph_runtime_state=None,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
         mock_config=None,
     )
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index b0661f7d29..ebabf66b41 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -6,6 +6,7 @@ from unittest import mock
 import pytest

 from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
 from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.model_runtime.entities.common_entities import I18nObject
@@ -32,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
 )
 from core.workflow.nodes.llm.file_saver import LLMFileSaver
 from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from core.workflow.runtime import GraphRuntimeState, VariablePool
 from core.workflow.system_variable import SystemVariable
 from models.enums import UserFrom
@@ -100,6 +102,8 @@ def llm_node(
     llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
 ) -> LLMNode:
     mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
+    mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
+    mock_model_factory = mock.MagicMock(spec=ModelFactory)
     node_config = {
         "id": "1",
         "data": llm_node_data.model_dump(),
@@ -109,13 +113,29 @@ def llm_node(
         config=node_config,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
+        credentials_provider=mock_credentials_provider,
+        model_factory=mock_model_factory,
         llm_file_saver=mock_file_saver,
     )
     return node


 @pytest.fixture
-def model_config():
+def model_config(monkeypatch):
+    from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
+
+    def mock_plugin_model_providers(_self):
+        providers = MockModelClass().fetch_model_providers("test")
+        for provider in providers:
+            provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}"
+        return providers
+
+    monkeypatch.setattr(
+        ModelProviderFactory,
+        "get_plugin_model_providers",
+        mock_plugin_model_providers,
+    )
+
     # Create actual provider and model type instances
     model_provider_factory = ModelProviderFactory(tenant_id="test")
     provider_instance = model_provider_factory.get_plugin_model_provider("openai")
@@ -125,7 +145,7 @@ def model_config():
     provider_model_bundle = ProviderModelBundle(
         configuration=ProviderConfiguration(
             tenant_id="1",
-            provider=provider_instance,
+            provider=provider_instance.declaration,
             preferred_provider_type=ProviderType.CUSTOM,
             using_provider_type=ProviderType.CUSTOM,
             system_configuration=SystemConfiguration(enabled=False),
@@ -153,6 +173,89 @@ def model_config():
     )


+def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity):
+    mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
+    mock_model_factory = mock.MagicMock(spec=ModelFactory)
+
+    provider_model_bundle = model_config.provider_model_bundle
+    model_type_instance = provider_model_bundle.model_type_instance
+    provider_model = mock.MagicMock()
+
+    model_instance = mock.MagicMock(
+        model_type_instance=model_type_instance,
+        provider_model_bundle=provider_model_bundle,
+    )
+
+    mock_credentials_provider.fetch.return_value = {"api_key": "test"}
+    mock_model_factory.init_model_instance.return_value = model_instance
+
+    with (
+        mock.patch.object(
+            provider_model_bundle.configuration.__class__,
+            "get_provider_model",
+            return_value=provider_model,
+        ),
+        mock.patch.object(
+            model_type_instance.__class__,
+            "get_model_schema",
+            return_value=model_config.model_schema,
+        ),
+    ):
+        fetch_model_config(
+            node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+            credentials_provider=mock_credentials_provider,
+            model_factory=mock_model_factory,
+        )
+
+    mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo")
+    mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo")
+    provider_model.raise_for_status.assert_called_once()
+
+
+def test_dify_model_access_adapters_call_managers():
+    mock_provider_manager = mock.MagicMock()
+    mock_model_manager = mock.MagicMock()
+    mock_configurations = mock.MagicMock()
+    mock_provider_configuration = mock.MagicMock()
+    mock_provider_model = mock.MagicMock()
+
+    mock_configurations.get.return_value = mock_provider_configuration
+    mock_provider_configuration.get_provider_model.return_value = mock_provider_model
+    mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
+
+    credentials_provider = DifyCredentialsProvider(
+        tenant_id="tenant",
+        provider_manager=mock_provider_manager,
+    )
+    model_factory = DifyModelFactory(
+        tenant_id="tenant",
+        model_manager=mock_model_manager,
+    )
+
+    mock_provider_manager.get_configurations.return_value = mock_configurations
+
+    credentials_provider.fetch("openai", "gpt-3.5-turbo")
+    model_factory.init_model_instance("openai", "gpt-3.5-turbo")
+
+    mock_provider_manager.get_configurations.assert_called_once_with("tenant")
+    mock_configurations.get.assert_called_once_with("openai")
+    mock_provider_configuration.get_provider_model.assert_called_once_with(
+        model_type=ModelType.LLM,
+        model="gpt-3.5-turbo",
+    )
+    mock_provider_configuration.get_current_credentials.assert_called_once_with(
+        model_type=ModelType.LLM,
+        model="gpt-3.5-turbo",
+    )
+    mock_provider_model.raise_for_status.assert_called_once()
+    mock_model_manager.get_model_instance.assert_called_once_with(
+        tenant_id="tenant",
+        provider="openai",
+        model_type=ModelType.LLM,
+        model="gpt-3.5-turbo",
+    )
+
+
 def test_fetch_files_with_file_segment():
     file = File(
         id="1",
@@ -485,6 +588,8 @@ def test_handle_list_messages_basic(llm_node):

 @pytest.fixture
 def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
     mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
+    mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
+    mock_model_factory = mock.MagicMock(spec=ModelFactory)
     node_config = {
         "id": "1",
         "data": llm_node_data.model_dump(),
@@ -494,6 +599,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
         config=node_config,
         graph_init_params=graph_init_params,
         graph_runtime_state=graph_runtime_state,
+        credentials_provider=mock_credentials_provider,
+        model_factory=mock_model_factory,
         llm_file_saver=mock_file_saver,
     )
     return node, mock_file_saver
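The two new tests above drive the adapters with mocked managers. For orientation, here is how the same pieces would plausibly be wired outside the tests: construct the per-tenant adapters and hand them to `fetch_model_config`. The names come from the imports in this diff; the wrapper function and the concrete model values are illustrative only:

from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
from core.workflow.nodes.llm.entities import ModelConfig


def load_llm_for_tenant(tenant_id: str):
    credentials_provider = DifyCredentialsProvider(tenant_id=tenant_id)
    model_factory = DifyModelFactory(tenant_id=tenant_id)
    node_model = ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={})
    # Per the tests above, this resolves credentials via the provider port,
    # builds the model instance via the factory port, and returns both the
    # instance and its config-with-credentials entity.
    return fetch_model_config(
        node_data_model=node_model,
        credentials_provider=credentials_provider,
        model_factory=model_factory,
    )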
diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py
index 5deec10d5e..c805dd98e2 100644
--- a/api/tests/unit_tests/services/dataset_service_update_delete.py
+++ b/api/tests/unit_tests/services/dataset_service_update_delete.py
@@ -642,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
         # Mock embedding model
         mock_embedding_model = Mock()
-        mock_embedding_model.model = "text-embedding-ada-002"
+        mock_embedding_model.model_name = "text-embedding-ada-002"
         mock_embedding_model.provider = "openai"
+        mock_embedding_model.credentials = {}
+
+        mock_model_schema = Mock()
+        mock_model_schema.features = []
+
+        mock_text_embedding_model = Mock()
+        mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
+        mock_embedding_model.model_type_instance = mock_text_embedding_model

         mock_model_instance = Mock()
         mock_model_instance.get_model_instance.return_value = mock_embedding_model
diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service.py
index 87fd29bbc0..80cce81e89 100644
--- a/api/tests/unit_tests/services/test_dataset_service.py
+++ b/api/tests/unit_tests/services/test_dataset_service.py
@@ -174,7 +174,7 @@ class DatasetServiceTestDataFactory:
             Mock: Embedding model mock with model and provider attributes
         """
         embedding_model = Mock()
-        embedding_model.model = model
+        embedding_model.model_name = model
         embedding_model.provider = provider
         return embedding_model
@@ -434,7 +434,7 @@ class TestDatasetServiceCreateDataset:
         # Assert
         assert result.indexing_technique == "high_quality"
         assert result.embedding_model_provider == embedding_model.provider
-        assert result.embedding_model == embedding_model.model
+        assert result.embedding_model == embedding_model.model_name
         mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
             tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
         )
diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py
index 4d63c5f911..7c7a70f962 100644
--- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py
+++ b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py
@@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
     def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
         """Create a mock embedding model."""
         embedding_model = Mock()
-        embedding_model.model = model
+        embedding_model.model_name = model
         embedding_model.provider = provider
         return embedding_model
@@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
         # Assert
         assert result.indexing_technique == "high_quality"
         assert result.embedding_model_provider == embedding_model.provider
-        assert result.embedding_model == embedding_model.model
+        assert result.embedding_model == embedding_model.model_name
         mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
             tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
         )
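The dataset_service_update_delete.py hunk above shows the embedding-model mock growing extra structure (credentials, a schema with empty `features`, a `model_type_instance`), which suggests the service path now reads the model schema. A reusable factory would keep all of the dataset-service mocks consistent with that shape; the attribute names follow the diff, while the helper itself is hypothetical:

from unittest.mock import Mock


def make_embedding_model_mock(model_name: str = "text-embedding-ada-002",
                              provider: str = "openai") -> Mock:
    """Build an embedding-model mock with the attributes the service now reads."""
    embedding_model = Mock()
    embedding_model.model_name = model_name
    embedding_model.provider = provider
    embedding_model.credentials = {}

    model_schema = Mock()
    model_schema.features = []  # no optional model features advertised

    model_type_instance = Mock()
    model_type_instance.get_model_schema.return_value = model_schema
    embedding_model.model_type_instance = model_type_instance
    return embedding_model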