From 6d5b3863949a0902c751c505491801175a514636 Mon Sep 17 00:00:00 2001 From: Yeuoly <45712896+Yeuoly@users.noreply.github.com> Date: Tue, 30 Jan 2024 15:25:37 +0800 Subject: [PATCH] Feat/blocking function call (#2247) --- api/core/app_runner/assistant_app_runner.py | 14 +- api/core/features/assistant_base_runner.py | 15 ++- api/core/features/assistant_cot_runner.py | 7 +- api/core/features/assistant_fc_runner.py | 126 +++++++++++++++--- .../model_runtime/entities/model_entities.py | 1 + .../model_providers/azure_openai/_constant.py | 5 + .../model_providers/azure_openai/llm/llm.py | 28 +++- .../model_providers/chatglm/llm/llm.py | 6 +- .../minimax/llm/abab5.5-chat.yaml | 2 + .../minimax/llm/abab6-chat.yaml | 2 + .../minimax/llm/chat_completion.py | 3 +- .../minimax/llm/chat_completion_pro.py | 46 +++++-- .../model_providers/minimax/llm/llm.py | 43 +++++- .../model_providers/minimax/llm/types.py | 10 ++ .../openai/llm/gpt-3.5-turbo-0613.yaml | 1 + .../openai/llm/gpt-3.5-turbo-1106.yaml | 1 + .../openai/llm/gpt-3.5-turbo-16k-0613.yaml | 1 + .../openai/llm/gpt-3.5-turbo-16k.yaml | 1 + .../openai/llm/gpt-3.5-turbo.yaml | 1 + .../openai/llm/gpt-4-0125-preview.yaml | 1 + .../openai/llm/gpt-4-1106-preview.yaml | 1 + .../model_providers/openai/llm/gpt-4-32k.yaml | 1 + .../openai/llm/gpt-4-turbo-preview.yaml | 1 + .../model_providers/openai/llm/gpt-4.yaml | 1 + .../model_providers/openai/llm/llm.py | 2 +- .../model_providers/xinference/llm/llm.py | 31 ++++- .../text_embedding/text_embedding.py | 23 +++- .../xinference/{llm => }/xinference_helper.py | 25 +++- .../zhipuai/llm/glm_3_turbo.yaml | 4 + .../model_providers/zhipuai/llm/glm_4.yaml | 4 + .../model_providers/zhipuai/llm/llm.py | 21 +++ api/requirements.txt | 2 +- .../model_runtime/__mock/xinference.py | 93 ++++++++----- 33 files changed, 429 insertions(+), 94 deletions(-) rename api/core/model_runtime/model_providers/xinference/{llm => }/xinference_helper.py (75%) diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index e440d1550d..b3cbeaf81c 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -11,6 +11,7 @@ from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationException from core.tools.entities.tool_entities import ToolRuntimeVariablePool @@ -194,6 +195,13 @@ class AssistantApplicationRunner(AppRunner): memory=memory, ) + # change function call strategy based on LLM model + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + + if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features): + agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING + # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: assistant_cot_runner = AssistantCotApplicationRunner( @@ -209,9 +217,9 @@ class AssistantApplicationRunner(AppRunner): prompt_messages=prompt_message, variables_pool=tool_variables, db_variables=tool_conversation_variables, + 
model_instance=model_instance ) invoke_result = assistant_cot_runner.run( - model_instance=model_instance, conversation=conversation, message=message, query=query, @@ -229,10 +237,10 @@ class AssistantApplicationRunner(AppRunner): memory=memory, prompt_messages=prompt_message, variables_pool=tool_variables, - db_variables=tool_conversation_variables + db_variables=tool_conversation_variables, + model_instance=model_instance ) invoke_result = assistant_fc_runner.run( - model_instance=model_instance, conversation=conversation, message=message, query=query, diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 32f7f1d49f..8919033a1e 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -1,7 +1,7 @@ import logging import json -from typing import Optional, List, Tuple, Union +from typing import Optional, List, Tuple, Union, cast from datetime import datetime from mimetypes import guess_extension @@ -27,7 +27,10 @@ from core.entities.application_entities import ModelConfigEntity, \ AgentEntity, AppOrchestrationConfigEntity, ApplicationGenerateEntity, InvokeFrom from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.utils.encoders import jsonable_encoder +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_manager import ModelInstance from core.file.message_file_parser import FileTransferMethod logger = logging.getLogger(__name__) @@ -45,6 +48,7 @@ class BaseAssistantApplicationRunner(AppRunner): prompt_messages: Optional[List[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, + model_instance: ModelInstance = None ) -> None: """ Agent runner @@ -71,6 +75,7 @@ class BaseAssistantApplicationRunner(AppRunner): self.history_prompt_messages = prompt_messages self.variables_pool = variables_pool self.db_variables_pool = db_variables + self.model_instance = model_instance # init callback self.agent_callback = DifyAgentCallbackHandler() @@ -95,6 +100,14 @@ class BaseAssistantApplicationRunner(AppRunner): MessageAgentThought.message_id == self.message.id, ).count() + # check if model supports stream tool call + llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) + model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials) + if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []): + self.stream_tool_call = True + else: + self.stream_tool_call = False + def _repacket_app_orchestration_config(self, app_orchestration_config: AppOrchestrationConfigEntity) -> AppOrchestrationConfigEntity: """ Repacket app orchestration config diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index 6fcc02d857..c7aec6965c 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -20,8 +20,7 @@ from core.features.assistant_base_runner import BaseAssistantApplicationRunner from models.model import Conversation, Message class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): - def run(self, model_instance: ModelInstance, - conversation: Conversation, + def run(self, conversation: Conversation, message: 
Message, query: str, ) -> Union[Generator, LLMResult]: @@ -82,6 +81,8 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): llm_usage.prompt_price += usage.prompt_price llm_usage.completion_price += usage.completion_price + model_instance = self.model_instance + while function_call_state and iteration_step <= max_iteration_steps: # continue to run until there is not any tool call function_call_state = False @@ -390,7 +391,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): # remove Action: xxx from agent thought agent_thought = re.sub(r'Action:.*', '', agent_thought, flags=re.IGNORECASE) - if action_name and action_input: + if action_name and action_input is not None: return AgentScratchpadUnit( agent_response=content, thought=agent_thought, diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py index 03c4e87015..c6df20dfd3 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/features/assistant_fc_runner.py @@ -5,7 +5,7 @@ from typing import Union, Generator, Dict, Any, Tuple, List from core.model_runtime.entities.message_entities import PromptMessage, UserPromptMessage,\ SystemPromptMessage, AssistantPromptMessage, ToolPromptMessage, PromptMessageTool -from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage +from core.model_runtime.entities.llm_entities import LLMResultChunk, LLMResult, LLMUsage, LLMResultChunkDelta from core.model_manager import ModelInstance from core.application_queue_manager import PublishFrom @@ -20,8 +20,7 @@ from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): - def run(self, model_instance: ModelInstance, - conversation: Conversation, + def run(self, conversation: Conversation, message: Message, query: str, ) -> Generator[LLMResultChunk, None, None]: @@ -81,6 +80,8 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): llm_usage.prompt_price += usage.prompt_price llm_usage.completion_price += usage.completion_price + model_instance = self.model_instance + while function_call_state and iteration_step <= max_iteration_steps: function_call_state = False @@ -101,12 +102,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): # recale llm max tokens self.recale_llm_max_tokens(self.model_config, prompt_messages) # invoke model - chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm( + chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_orchestration_config.model_config.parameters, tools=prompt_messages_tools, stop=app_orchestration_config.model_config.stop, - stream=True, + stream=self.stream_tool_call, user=self.user_id, callbacks=[], ) @@ -122,11 +123,41 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): current_llm_usage = None - for chunk in chunks: + if self.stream_tool_call: + for chunk in chunks: + # check if there is any tool call + if self.check_tool_calls(chunk): + function_call_state = True + tool_calls.extend(self.extract_tool_calls(chunk)) + tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) + try: + tool_call_inputs = json.dumps({ + tool_call[1]: tool_call[2] for tool_call in tool_calls + }, ensure_ascii=False) + except json.JSONDecodeError as e: + # ensure ascii to avoid encoding 
error + tool_call_inputs = json.dumps({ + tool_call[1]: tool_call[2] for tool_call in tool_calls + }) + + if chunk.delta.message and chunk.delta.message.content: + if isinstance(chunk.delta.message.content, list): + for content in chunk.delta.message.content: + response += content.data + else: + response += chunk.delta.message.content + + if chunk.delta.usage: + increase_usage(llm_usage, chunk.delta.usage) + current_llm_usage = chunk.delta.usage + + yield chunk + else: + result: LLMResult = chunks # check if there is any tool call - if self.check_tool_calls(chunk): + if self.check_blocking_tool_calls(result): function_call_state = True - tool_calls.extend(self.extract_tool_calls(chunk)) + tool_calls.extend(self.extract_blocking_tool_calls(result)) tool_call_names = ';'.join([tool_call[1] for tool_call in tool_calls]) try: tool_call_inputs = json.dumps({ @@ -138,18 +169,44 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): tool_call[1]: tool_call[2] for tool_call in tool_calls }) - if chunk.delta.message and chunk.delta.message.content: - if isinstance(chunk.delta.message.content, list): - for content in chunk.delta.message.content: + if result.usage: + increase_usage(llm_usage, result.usage) + current_llm_usage = result.usage + + if result.message and result.message.content: + if isinstance(result.message.content, list): + for content in result.message.content: response += content.data else: - response += chunk.delta.message.content + response += result.message.content - if chunk.delta.usage: - increase_usage(llm_usage, chunk.delta.usage) - current_llm_usage = chunk.delta.usage + if not result.message.content: + result.message.content = '' - yield chunk + yield LLMResultChunk( + model=model_instance.model, + prompt_messages=result.prompt_messages, + system_fingerprint=result.system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=result.message, + usage=result.usage, + ) + ) + + if tool_calls: + prompt_messages.append(AssistantPromptMessage( + content='', + name='', + tool_calls=[AssistantPromptMessage.ToolCall( + id=tool_call[0], + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool_call[1], + arguments=json.dumps(tool_call[2], ensure_ascii=False) + ) + ) for tool_call in tool_calls] + )) # save thought self.save_agent_thought( @@ -167,6 +224,12 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): final_answer += response + '\n' + # update prompt messages + if response.strip(): + prompt_messages.append(AssistantPromptMessage( + content=response, + )) + # call tools tool_responses = [] for tool_call_id, tool_call_name, tool_call_args in tool_calls: @@ -256,12 +319,6 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ) self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) - # update prompt messages - if response.strip(): - prompt_messages.append(AssistantPromptMessage( - content=response, - )) - # update prompt tool for prompt_tool in prompt_messages_tools: self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool) @@ -287,6 +344,14 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): if llm_result_chunk.delta.message.tool_calls: return True return False + + def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool: + """ + Check if there is any blocking tool call in llm result + """ + if llm_result.message.tool_calls: + return True + return False def 
extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: """ @@ -304,6 +369,23 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): )) return tool_calls + + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: + """ + Extract blocking tool calls from llm result + + Returns: + List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)] + """ + tool_calls = [] + for prompt_message in llm_result.message.tool_calls: + tool_calls.append(( + prompt_message.id, + prompt_message.function.name, + json.loads(prompt_message.function.arguments), + )) + + return tool_calls def organize_prompt_messages(self, prompt_template: str, query: str = None, diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 23c492cedb..2041cb3a97 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -78,6 +78,7 @@ class ModelFeature(Enum): MULTI_TOOL_CALL = "multi-tool-call" AGENT_THOUGHT = "agent-thought" VISION = "vision" + STREAM_TOOL_CALL = "stream-tool-call" class DefaultParameterName(Enum): diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 75c7ec508b..8104df52dd 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -36,6 +36,7 @@ LLM_BASE_MODELS = [ features=[ ModelFeature.AGENT_THOUGHT, ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ @@ -80,6 +81,7 @@ LLM_BASE_MODELS = [ features=[ ModelFeature.AGENT_THOUGHT, ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ @@ -124,6 +126,7 @@ LLM_BASE_MODELS = [ features=[ ModelFeature.AGENT_THOUGHT, ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ @@ -198,6 +201,7 @@ LLM_BASE_MODELS = [ features=[ ModelFeature.AGENT_THOUGHT, ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ @@ -272,6 +276,7 @@ LLM_BASE_MODELS = [ features=[ ModelFeature.AGENT_THOUGHT, ModelFeature.MULTI_TOOL_CALL, + ModelFeature.STREAM_TOOL_CALL, ], fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index c1a5e23bc2..326043aa39 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -324,6 +324,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tools: Optional[list[PromptMessageTool]] = None) -> Generator: index = 0 full_assistant_content = '' + delta_assistant_message_function_call_storage: ChoiceDeltaFunctionCall = None real_model = model system_fingerprint = None completion = '' @@ -333,12 +334,32 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): delta = chunk.choices[0] - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): + if delta.finish_reason is None and 
(delta.delta.content is None or delta.delta.content == '') and \ + delta.delta.function_call is None: continue - + # assistant_message_tool_calls = delta.delta.tool_calls assistant_message_function_call = delta.delta.function_call + # extract tool calls from response + if delta_assistant_message_function_call_storage is not None: + # handle process of stream function call + if assistant_message_function_call: + # message has not ended ever + delta_assistant_message_function_call_storage.arguments += assistant_message_function_call.arguments + continue + else: + # message has ended + assistant_message_function_call = delta_assistant_message_function_call_storage + delta_assistant_message_function_call_storage = None + else: + if assistant_message_function_call: + # start of stream function call + delta_assistant_message_function_call_storage = assistant_message_function_call + if delta_assistant_message_function_call_storage.arguments is None: + delta_assistant_message_function_call_storage.arguments = '' + continue + # extract tool calls from response # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) function_call = self._extract_response_function_call(assistant_message_function_call) @@ -489,7 +510,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): else: raise ValueError(f"Got unknown type {message}") - if message.name is not None: + if message.name: message_dict["name"] = message.name return message_dict @@ -586,7 +607,6 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): num_tokens = 0 for tool in tools: num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(tool.get("type"))) num_tokens += len(encoding.encode('function')) # calculate num tokens for function object diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index 44868fcf73..471898fcf6 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -5,7 +5,7 @@ from typing import Generator, List, Optional, cast from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction, - PromptMessageTool, SystemPromptMessage, UserPromptMessage) + PromptMessageTool, SystemPromptMessage, UserPromptMessage, ToolPromptMessage) from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -194,6 +194,10 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + # check if last message is user message + message = cast(ToolPromptMessage, message) + message_dict = {"role": "function", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml index cacccdb669..c0ad1e2fdf 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml +++ 
b/api/core/model_runtime/model_providers/minimax/llm/abab5.5-chat.yaml @@ -4,6 +4,8 @@ label: model_type: llm features: - agent-thought + - tool-call + - stream-tool-call model_properties: mode: chat context_size: 16384 diff --git a/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml b/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml index 258f2a9188..4c487c598e 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml +++ b/api/core/model_runtime/model_providers/minimax/llm/abab6-chat.yaml @@ -4,6 +4,8 @@ label: model_type: llm features: - agent-thought + - tool-call + - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index f688b348e5..718ebb1013 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -16,7 +16,7 @@ class MinimaxChatCompletion(object): """ def generate(self, model: str, api_key: str, group_id: str, prompt_messages: List[MinimaxMessage], model_parameters: dict, - tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \ + tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion @@ -162,7 +162,6 @@ class MinimaxChatCompletion(object): continue for choice in choices: - print(choice) message = choice['delta'] yield MinimaxMessage( content=message, diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 71b337e152..6233af26b6 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -17,7 +17,7 @@ class MinimaxChatCompletionPro(object): """ def generate(self, model: str, api_key: str, group_id: str, prompt_messages: List[MinimaxMessage], model_parameters: dict, - tools: Dict[str, Any], stop: List[str] | None, stream: bool, user: str) \ + tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion @@ -82,6 +82,10 @@ class MinimaxChatCompletionPro(object): **extra_kwargs } + if tools: + body['functions'] = tools + body['function_call'] = { 'type': 'auto' } + try: response = post( url=url, data=dumps(body), headers=headers, stream=stream, timeout=(10, 300)) @@ -135,6 +139,7 @@ class MinimaxChatCompletionPro(object): """ handle stream chat generate response """ + function_call_storage = None for line in response.iter_lines(): if not line: continue @@ -148,7 +153,7 @@ class MinimaxChatCompletionPro(object): msg = data['base_resp']['status_msg'] self._handle_error(code, msg) - if data['reply']: + if data['reply'] or 'usage' in data and data['usage']: total_tokens = data['usage']['total_tokens'] message = MinimaxMessage( role=MinimaxMessage.Role.ASSISTANT.value, @@ -160,6 +165,12 @@ class MinimaxChatCompletionPro(object): 'total_tokens': total_tokens } message.stop_reason = data['choices'][0]['finish_reason'] + + if function_call_storage: + function_call_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) + function_call_message.function_call = function_call_storage 
+ yield function_call_message + yield message return @@ -168,11 +179,28 @@ class MinimaxChatCompletionPro(object): continue for choice in choices: - message = choice['messages'][0]['text'] - if not message: - continue + message = choice['messages'][0] + + if 'function_call' in message: + if not function_call_storage: + function_call_storage = message['function_call'] + if 'arguments' not in function_call_storage or not function_call_storage['arguments']: + function_call_storage['arguments'] = '' + continue + else: + function_call_storage['arguments'] += message['function_call']['arguments'] + continue + else: + if function_call_storage: + message['function_call'] = function_call_storage + function_call_storage = None - yield MinimaxMessage( - content=message, - role=MinimaxMessage.Role.ASSISTANT.value - ) \ No newline at end of file + minimax_message = MinimaxMessage(content='', role=MinimaxMessage.Role.ASSISTANT.value) + + if 'function_call' in message: + minimax_message.function_call = message['function_call'] + + if 'text' in message: + minimax_message.content = message['text'] + + yield minimax_message \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index 7d8ae22317..86a246c714 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -2,7 +2,7 @@ from typing import Generator, List from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) + SystemPromptMessage, UserPromptMessage, ToolPromptMessage) from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError @@ -84,6 +84,13 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): """ client: MinimaxChatCompletionPro = self.model_apis[model]() + if tools: + tools = [{ + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters + } for tool in tools] + response = client.generate( model=model, api_key=credentials['minimax_api_key'], @@ -109,7 +116,19 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): elif isinstance(prompt_message, UserPromptMessage): return MinimaxMessage(role=MinimaxMessage.Role.USER.value, content=prompt_message.content) elif isinstance(prompt_message, AssistantPromptMessage): + if prompt_message.tool_calls: + message = MinimaxMessage( + role=MinimaxMessage.Role.ASSISTANT.value, + content='' + ) + message.function_call={ + 'name': prompt_message.tool_calls[0].function.name, + 'arguments': prompt_message.tool_calls[0].function.arguments + } + return message return MinimaxMessage(role=MinimaxMessage.Role.ASSISTANT.value, content=prompt_message.content) + elif isinstance(prompt_message, ToolPromptMessage): + return MinimaxMessage(role=MinimaxMessage.Role.FUNCTION.value, content=prompt_message.content) else: raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported') @@ -151,6 +170,28 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): finish_reason=message.stop_reason if message.stop_reason else None, ), ) + elif message.function_call: + if 'name' not in message.function_call 
or 'arguments' not in message.function_call: + continue + + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content='', + tool_calls=[AssistantPromptMessage.ToolCall( + id='', + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=message.function_call['name'], + arguments=message.function_call['arguments'] + ) + )] + ), + ), + ) else: yield LLMResultChunk( model=model, diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 6d1e8e64d8..6229312445 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -7,13 +7,23 @@ class MinimaxMessage: USER = 'USER' ASSISTANT = 'BOT' SYSTEM = 'SYSTEM' + FUNCTION = 'FUNCTION' role: str = Role.USER.value content: str usage: Dict[str, int] = None stop_reason: str = '' + function_call: Dict[str, Any] = None def to_dict(self) -> Dict[str, Any]: + if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: + return { + 'sender_type': 'BOT', + 'sender_name': '专家', + 'text': '', + 'function_call': self.function_call + } + return { 'sender_type': self.role, 'sender_name': '我' if self.role == 'USER' else '专家', diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml index bc130b02fc..6d519cbee6 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 4096 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml index ff260bb367..4b5d31c774 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 16385 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml index 93d0113b8a..a86bacb34f 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 16385 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml index ddb4da775b..467041e842 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 16385 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml 
b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml index ee8ad8d302..fddf1836c4 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 4096 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml index d70395c566..943a6de321 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-0125-preview.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml index 73258fedae..7f3bdaeac1 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-1106-preview.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml index 296c88a379..b1e61317e9 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-32k.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml index c651a4e0e9..b109cfc814 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4-turbo-preview.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml index 14ab1b26bd..48e8930608 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4.yaml @@ -6,6 +6,7 @@ model_type: llm features: - multi-tool-call - agent-thought + - stream-tool-call model_properties: mode: chat context_size: 8192 diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index f8fc5db99a..7722c69a95 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -671,7 +671,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): else: raise ValueError(f"Got unknown type {message}") - if message.name is not None: + if message.name: message_dict["name"] = message.name return message_dict diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 8f068f564d..dc9b594d5a 
100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -3,14 +3,14 @@ from typing import Generator, Iterator, List, Optional, Union, cast from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) + SystemPromptMessage, UserPromptMessage, ToolPromptMessage) from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule, ParameterType) + ParameterRule, ParameterType, ModelFeature) from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.xinference.llm.xinference_helper import (XinferenceHelper, +from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper, XinferenceModelExtraParameter) from core.model_runtime.utils import helper from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, @@ -33,6 +33,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): see `core.model_runtime.model_providers.__base.large_language_model.LargeLanguageModel._invoke` """ + if 'temperature' in model_parameters: + if model_parameters['temperature'] < 0.01: + model_parameters['temperature'] = 0.01 + elif model_parameters['temperature'] > 1.0: + model_parameters['temperature'] = 0.99 + return self._generate( model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user, @@ -65,6 +71,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): credentials['completion_type'] = 'completion' else: raise ValueError(f'xinference model ability {extra_param.model_ability} is not supported') + + if extra_param.support_function_call: + credentials['support_function_call'] = True except RuntimeError as e: raise CredentialsValidateFailedError(f'Xinference credentials validate failed: {e}') @@ -220,6 +229,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} else: raise ValueError(f"Unknown message type {type(message)}") @@ -237,7 +249,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): label=I18nObject( zh_Hans='温度', en_US='Temperature' - ) + ), ), ParameterRule( name='top_p', @@ -282,6 +294,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): completion_type = LLMMode.COMPLETION.value else: raise ValueError(f'xinference model ability {extra_args.model_ability} is not supported') + + support_function_call = credentials.get('support_function_call', False) entity = AIModelEntity( model=model, @@ -290,6 +304,9 @@ class 
XinferenceAILargeLanguageModel(LargeLanguageModel): ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.LLM, + features=[ + ModelFeature.TOOL_CALL + ] if support_function_call else [], model_properties={ ModelPropertyKey.MODE: completion_type, }, @@ -310,6 +327,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): extra_model_kwargs can be got by `XinferenceHelper.get_xinference_extra_parameter` """ + if 'server_url' not in credentials: + raise CredentialsValidateFailedError('server_url is required in credentials') + + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + client = OpenAI( base_url=f'{credentials["server_url"]}/v1', api_key='abc', diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index e7d7959417..28389db6a4 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -2,7 +2,7 @@ import time from typing import Optional from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType, ModelPropertyKey from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) @@ -10,6 +10,7 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle +from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ @@ -35,7 +36,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ server_url = credentials['server_url'] model_uid = credentials['model_uid'] - + + if server_url.endswith('/'): + server_url = server_url[:-1] + client = Client(base_url=server_url) try: @@ -102,8 +106,15 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: + server_url = credentials['server_url'] + model_uid = credentials['model_uid'] + extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) + + if extra_args.max_tokens: + credentials['max_tokens'] = extra_args.max_tokens + self._invoke(model=model, credentials=credentials, texts=['ping']) - except InvokeAuthorizationError: + except (InvokeAuthorizationError, RuntimeError): raise CredentialsValidateFailedError('Invalid api key') @property @@ -160,6 +171,7 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): """ used to define customizable model schema """ + entity = AIModelEntity( model=model, label=I18nObject( @@ -167,7 +179,10 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.TEXT_EMBEDDING, - model_properties={}, + model_properties={ + ModelPropertyKey.MAX_CHUNKS: 1, + ModelPropertyKey.CONTEXT_SIZE: 'max_tokens' in credentials and 
credentials['max_tokens'] or 512, + }, parameter_rules=[] ) diff --git a/api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py similarity index 75% rename from api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py rename to api/core/model_runtime/model_providers/xinference/xinference_helper.py index 88b5a558ac..64612ca3fa 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,6 +1,7 @@ from threading import Lock from time import time from typing import List +from os import path from requests import get from requests.adapters import HTTPAdapter @@ -12,11 +13,16 @@ class XinferenceModelExtraParameter(object): model_format: str model_handle_type: str model_ability: List[str] + max_tokens: int = 512 + support_function_call: bool = False - def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str]) -> None: + def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str], + support_function_call: bool, max_tokens: int) -> None: self.model_format = model_format self.model_handle_type = model_handle_type self.model_ability = model_ability + self.support_function_call = support_function_call + self.max_tokens = max_tokens cache = {} cache_lock = Lock() @@ -49,7 +55,7 @@ class XinferenceHelper: get xinference model extra parameter like model_format and model_handle_type """ - url = f'{server_url}/v1/models/{model_uid}' + url = path.join(server_url, 'v1/models', model_uid) # this methid is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3 session = Session() @@ -66,10 +72,12 @@ class XinferenceHelper: response_json = response.json() - model_format = response_json['model_format'] - model_ability = response_json['model_ability'] + model_format = response_json.get('model_format', 'ggmlv3') + model_ability = response_json.get('model_ability', []) - if model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: + if response_json.get('model_type') == 'embedding': + model_handle_type = 'embedding' + elif model_format == 'ggmlv3' and 'chatglm' in response_json['model_name']: model_handle_type = 'chatglm' elif 'generate' in model_ability: model_handle_type = 'generate' @@ -78,8 +86,13 @@ class XinferenceHelper: else: raise NotImplementedError(f'xinference model handle type {model_handle_type} is not supported') + support_function_call = 'tools' in model_ability + max_tokens = response_json.get('max_tokens', 512) + return XinferenceModelExtraParameter( model_format=model_format, model_handle_type=model_handle_type, - model_ability=model_ability + model_ability=model_ability, + support_function_call=support_function_call, + max_tokens=max_tokens ) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml index bf8f9cfd5e..b0027d01e3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml @@ -2,6 +2,10 @@ model: glm-3-turbo label: en_US: glm-3-turbo model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call model_properties: mode: chat parameter_rules: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml 
b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index c8329b7e37..ca7b1c1f45 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -2,6 +2,10 @@ model: glm-4 label: en_US: glm-4 model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call model_properties: mode: chat parameter_rules: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index c4c1dfb85b..eafae9ab6f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -194,6 +194,27 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): 'content': prompt_message.content, 'tool_call_id': prompt_message.tool_call_id }) + elif isinstance(prompt_message, AssistantPromptMessage): + if prompt_message.tool_calls: + params['messages'].append({ + 'role': 'assistant', + 'content': prompt_message.content, + 'tool_calls': [ + { + 'id': tool_call.id, + 'type': tool_call.type, + 'function': { + 'name': tool_call.function.name, + 'arguments': tool_call.function.arguments + } + } for tool_call in prompt_message.tool_calls + ] + }) + else: + params['messages'].append({ + 'role': 'assistant', + 'content': prompt_message.content + }) else: params['messages'].append({ 'role': prompt_message.role.value, diff --git a/api/requirements.txt b/api/requirements.txt index 97c21a5c61..a4757ed0fe 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -47,7 +47,7 @@ dashscope[tokenizer]~=1.14.0 huggingface_hub~=0.16.4 transformers~=4.31.0 pandas==1.5.3 -xinference-client~=0.6.4 +xinference-client~=0.8.1 safetensors==0.3.2 zhipuai==1.0.7 werkzeug~=3.0.1 diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py index f5c61f4725..e4cc2ceea6 100644 --- a/api/tests/integration_tests/model_runtime/__mock/xinference.py +++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py @@ -19,58 +19,86 @@ class MockXinferenceClass(object): raise RuntimeError('404 Not Found') if 'generate' == model_uid: - return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url) + return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={}) if 'chat' == model_uid: - return RESTfulChatModelHandle(model_uid, base_url=self.base_url) + return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={}) if 'embedding' == model_uid: - return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url) + return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={}) if 'rerank' == model_uid: - return RESTfulRerankModelHandle(model_uid, base_url=self.base_url) + return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={}) raise RuntimeError('404 Not Found') def get(self: Session, url: str, **kwargs): - if '/v1/models/' in url: - response = Response() - + response = Response() + if 'v1/models/' in url: # get model uid model_uid = url.split('/')[-1] if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \ model_uid not in ['generate', 'chat', 'embedding', 'rerank']: response.status_code = 404 - raise ConnectionError('404 Not Found') + return response # check if url is valid if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url): response.status_code = 404 - raise 
ConnectionError('404 Not Found') - + return response + + if model_uid in ['generate', 'chat']: + response.status_code = 200 + response._content = b'''{ + "model_type": "LLM", + "address": "127.0.0.1:43877", + "accelerators": [ + "0", + "1" + ], + "model_name": "chatglm3-6b", + "model_lang": [ + "en" + ], + "model_ability": [ + "generate", + "chat" + ], + "model_description": "latest chatglm3", + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantization": "none", + "model_hub": "huggingface", + "revision": null, + "context_length": 2048, + "replica": 1 + }''' + return response + + elif model_uid == 'embedding': + response.status_code = 200 + response._content = b'''{ + "model_type": "embedding", + "address": "127.0.0.1:43877", + "accelerators": [ + "0", + "1" + ], + "model_name": "bge", + "model_lang": [ + "en" + ], + "revision": null, + "max_tokens": 512 +}''' + return response + + elif 'v1/cluster/auth' in url: response.status_code = 200 response._content = b'''{ - "model_type": "LLM", - "address": "127.0.0.1:43877", - "accelerators": [ - "0", - "1" - ], - "model_name": "chatglm3-6b", - "model_lang": [ - "en" - ], - "model_ability": [ - "generate", - "chat" - ], - "model_description": "latest chatglm3", - "model_format": "pytorch", - "model_size_in_billions": 7, - "quantization": "none", - "model_hub": "huggingface", - "revision": null, - "context_length": 2048, - "replica": 1 + "auth": true }''' return response + def _check_cluster_authenticated(self): + self._cluster_authed = True + def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict: # check if self._model_uid is a valid uuid if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \ @@ -133,6 +161,7 @@ MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' def setup_xinference_mock(request, monkeypatch: MonkeyPatch): if MOCK: monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model) + monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated) monkeypatch.setattr(Session, 'get', MockXinferenceClass.get) monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding) monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
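
The core pattern this patch introduces in `assistant_fc_runner.py` is a single invoke path that works whether the model streams tool calls or returns them in one blocking response: the runner calls `invoke_llm(stream=self.stream_tool_call, ...)` and then either iterates `LLMResultChunk` deltas or unwraps a single `LLMResult`, normalizing both into the same response text and tool-call list. Below is a minimal, self-contained sketch of that dispatch, assuming simplified stand-in classes (`ToolCall`, `Message`, `Chunk`, `Result`) rather than Dify's actual `AssistantPromptMessage.ToolCall`, `LLMResultChunk`, and `LLMResult` entities; it illustrates the shape of the logic, not the real API.

```python
# Sketch of the streaming-vs-blocking tool-call dispatch used by the runner.
# All classes here are simplified, hypothetical stand-ins for Dify's entities.
from dataclasses import dataclass, field
from typing import Iterable, List, Tuple, Union


@dataclass
class ToolCall:                      # stand-in for AssistantPromptMessage.ToolCall
    id: str
    name: str
    arguments: str                   # JSON-encoded argument string


@dataclass
class Message:
    content: str = ""
    tool_calls: List[ToolCall] = field(default_factory=list)


@dataclass
class Chunk:                         # stand-in for an LLMResultChunk delta
    message: Message


@dataclass
class Result:                        # stand-in for a blocking LLMResult
    message: Message


def collect(response: Union[Iterable[Chunk], Result]) -> Tuple[str, List[ToolCall]]:
    """Normalize a streaming or blocking LLM response into (text, tool_calls)."""
    text: str = ""
    tool_calls: List[ToolCall] = []

    if isinstance(response, Result):
        # Blocking path: the whole answer and its tool calls arrive at once.
        text = response.message.content or ""
        tool_calls = list(response.message.tool_calls)
    else:
        # Streaming path: accumulate content and tool calls chunk by chunk,
        # mirroring how the runner extends `tool_calls` per delta.
        for chunk in response:
            text += chunk.message.content or ""
            tool_calls.extend(chunk.message.tool_calls)

    return text, tool_calls


if __name__ == "__main__":
    blocking = Result(Message("done", [ToolCall("1", "current_time", "{}")]))
    streamed = iter([Chunk(Message("do")), Chunk(Message("ne"))])
    print(collect(blocking))   # ('done', [ToolCall(...)])
    print(collect(streamed))   # ('done', [])
```

In the patch itself, the choice between the two paths is driven by the `stream-tool-call` model feature: providers that can emit tool calls incrementally (OpenAI, Azure OpenAI, Minimax, ZhipuAI) advertise it in their model YAML, while models that only support blocking tool calls fall back to a single non-streamed invocation whose result is re-wrapped into one `LLMResultChunk` before being yielded downstream.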