diff --git a/api/README.md b/api/README.md index 1d3559c694..5c3a667801 100644 --- a/api/README.md +++ b/api/README.md @@ -5,7 +5,7 @@ 1. Start the docker-compose stack The backend requires some middleware, including PostgreSQL, Redis, and Weaviate, which can be started together using `docker-compose`. - + ```bash cd ../docker docker-compose -f docker-compose.middleware.yaml -p dify up -d @@ -15,7 +15,7 @@ 3. Generate a `SECRET_KEY` in the `.env` file. ```bash - openssl rand -base64 42 + sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env ``` 3.5 If you use Anaconda, create a new environment and activate it ```bash @@ -46,7 +46,7 @@ ``` pip install -r requirements.txt --upgrade --force-reinstall ``` - + 6. Start backend: ```bash flask run --host 0.0.0.0 --port=5001 --debug diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 21ce9cb6af..4b648a4e28 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -27,7 +27,9 @@ from fields.app_fields import ( from libs.login import login_required from models.model import App, AppModelConfig, Site from services.app_model_config_service import AppModelConfigService - +from core.tools.utils.configuration import ToolParameterConfigurationManager +from core.tools.tool_manager import ToolManager +from core.entities.application_entities import AgentToolEntity def _get_app(app_id, tenant_id): app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first() @@ -236,7 +238,39 @@ class AppApi(Resource): def get(self, app_id): """Get app detail""" app_id = str(app_id) - app = _get_app(app_id, current_user.current_tenant_id) + app: App = _get_app(app_id, current_user.current_tenant_id) + + # get original app model config + model_config: AppModelConfig = app.app_model_config + agent_mode = model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + # get tool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + masked_parameter = {} + + # override tool parameters + tool['tool_parameters'] = masked_parameter + + # override agent mode + model_config.agent_mode = json.dumps(agent_mode) return app diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index f67fff4b06..117007d055 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,3 +1,4 @@ +import json from flask import request from flask_login import current_user @@ -7,6 +8,9 @@ from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required +from core.entities.application_entities import AgentToolEntity +from core.tools.tool_manager import ToolManager +from core.tools.utils.configuration import 
ToolParameterConfigurationManager from events.app_event import app_model_config_was_updated from extensions.ext_database import db from libs.login import login_required @@ -38,6 +42,82 @@ class ModelConfigResource(Resource): ) new_app_model_config = new_app_model_config.from_model_config_dict(model_configuration) + # get original app model config + original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( + AppModelConfig.id == app.app_model_config_id + ).first() + agent_mode = original_app_model_config.agent_mode_dict + # decrypt agent tool parameters if it's secret-input + parameter_map = {} + masked_parameter_map = {} + tool_map = {} + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + # get tool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + + # get decrypted parameters + if agent_tool_entity.tool_parameters: + parameters = manager.decrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + masked_parameter = manager.mask_tool_parameters(parameters or {}) + else: + parameters = {} + masked_parameter = {} + + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + masked_parameter_map[key] = masked_parameter + parameter_map[key] = parameters + tool_map[key] = tool_runtime + + # encrypt agent tool parameters if it's secret-input + agent_mode = new_app_model_config.agent_mode_dict + for tool in agent_mode.get('tools') or []: + agent_tool_entity = AgentToolEntity(**tool) + + # get tool + key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + if key in tool_map: + tool_runtime = tool_map[key] + else: + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id=current_user.current_tenant_id, + agent_tool=agent_tool_entity, + agent_callback=None + ) + + manager = ToolParameterConfigurationManager( + tenant_id=current_user.current_tenant_id, + tool_runtime=tool_runtime, + provider_name=agent_tool_entity.provider_id, + provider_type=agent_tool_entity.provider_type, + ) + manager.delete_tool_parameters_cache() + + # override parameters if it equals to masked parameters + if agent_tool_entity.tool_parameters: + if key not in masked_parameter_map: + continue + + if agent_tool_entity.tool_parameters == masked_parameter_map[key]: + agent_tool_entity.tool_parameters = parameter_map[key] + + # encrypt parameters + if agent_tool_entity.tool_parameters: + tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + + # update app model config + new_app_model_config.agent_mode = json.dumps(agent_mode) + db.session.add(new_app_model_config) db.session.flush() diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 817c75765a..931979c7f3 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -82,6 +82,30 @@ class ToolBuiltinProviderIconApi(Resource): icon_bytes, minetype = ToolManageService.get_builtin_tool_provider_icon(provider) return send_file(io.BytesIO(icon_bytes), mimetype=minetype) +class 
ToolModelProviderIconApi(Resource): + @setup_required + def get(self, provider): + icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider) + return send_file(io.BytesIO(icon_bytes), mimetype=mimetype) + +class ToolModelProviderListToolsApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self): + user_id = current_user.id + tenant_id = current_user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + + args = parser.parse_args() + + return ToolManageService.list_model_tool_provider_tools( + user_id, + tenant_id, + args['provider'], + ) class ToolApiProviderAddApi(Resource): @setup_required @@ -283,6 +307,8 @@ api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provide api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin/<provider>/update') api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin/<provider>/credentials_schema') api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin/<provider>/icon') +api.add_resource(ToolModelProviderIconApi, '/workspaces/current/tool-provider/model/<provider>/icon') +api.add_resource(ToolModelProviderListToolsApi, '/workspaces/current/tool-provider/model/tools') api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index d4a6b6aa4f..3f7cfcaea8 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -200,8 +200,8 @@ class DatasetSegmentApi(DatasetApiResource): parser.add_argument('segments', type=dict, required=False, nullable=True, location='json') args = parser.parse_args() - SegmentService.segment_create_args_validate(args['segments'], document) - segment = SegmentService.update_segment(args['segments'], segment, document, dataset) + SegmentService.segment_create_args_validate(args, document) + segment = SegmentService.update_segment(args, segment, document, dataset) return { 'data': marshal(segment, segment_fields), 'doc_form': document.doc_form diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index d9a3447bda..655a5a1c7c 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -195,6 +195,10 @@ class AssistantApplicationRunner(AppRunner): if set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL]).intersection(model_schema.features or []): agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING + db.session.refresh(conversation) + db.session.refresh(message) + db.session.close() + # start agent runner if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT: assistant_cot_runner = AssistantCotApplicationRunner( diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index 83f4f6929a..d3c91337c8 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -192,6 +192,8 @@ class BasicApplicationRunner(AppRunner): model=app_orchestration_config.model_config.model ) + db.session.close() + invoke_result = 
model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=app_orchestration_config.model_config.parameters, diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 5fd635bc3b..1cc56483ad 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -89,6 +89,10 @@ class GenerateTaskPipeline: Process generate task pipeline. :return: """ + db.session.refresh(self._conversation) + db.session.refresh(self._message) + db.session.close() + if stream: return self._process_stream_response() else: @@ -303,6 +307,7 @@ class GenerateTaskPipeline: .first() ) db.session.refresh(agent_thought) + db.session.close() if agent_thought: response = { @@ -330,6 +335,8 @@ class GenerateTaskPipeline: .filter(MessageFile.id == event.message_file_id) .first() ) + db.session.close() + # get extension if '.' in message_file.url: extension = f'.{message_file.url.split(".")[-1]}' @@ -413,6 +420,7 @@ class GenerateTaskPipeline: usage = llm_result.usage self._message = db.session.query(Message).filter(Message.id == self._message.id).first() + self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages) self._message.message_tokens = usage.prompt_tokens diff --git a/api/core/application_manager.py b/api/core/application_manager.py index e073eac4b9..9aca61c7bb 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -201,7 +201,7 @@ class ApplicationManager: logger.exception("Unknown Error when generating") queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER) finally: - db.session.remove() + db.session.close() def _handle_response(self, application_generate_entity: ApplicationGenerateEntity, queue_manager: ApplicationQueueManager, @@ -233,8 +233,6 @@ class ApplicationManager: else: logger.exception(e) raise e - finally: - db.session.remove() def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_dict: dict) \ -> AppOrchestrationConfigEntity: @@ -651,6 +649,7 @@ class ApplicationManager: db.session.add(conversation) db.session.commit() + db.session.refresh(conversation) else: conversation = ( db.session.query(Conversation) @@ -689,6 +688,7 @@ class ApplicationManager: db.session.add(message) db.session.commit() + db.session.refresh(message) for file in application_generate_entity.files: message_file = MessageFile( diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 2a4ae7e135..1d9541070f 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -114,6 +114,7 @@ class BaseAssistantApplicationRunner(AppRunner): self.agent_thought_count = db.session.query(MessageAgentThought).filter( MessageAgentThought.message_id == self.message.id, ).count() + db.session.close() # check if model supports stream tool call llm_model = cast(LargeLanguageModel, model_instance.model_type_instance) @@ -154,9 +155,9 @@ class BaseAssistantApplicationRunner(AppRunner): """ convert tool to prompt message tool """ - tool_entity = ToolManager.get_tool_runtime( - provider_type=tool.provider_type, provider_name=tool.provider_id, tool_name=tool.tool_name, - tenant_id=self.application_generate_entity.tenant_id, + tool_entity = ToolManager.get_agent_tool_runtime( + tenant_id=self.tenant_id, + agent_tool=tool, 
agent_callback=self.agent_callback ) tool_entity.load_variables(self.variables_pool) @@ -171,33 +172,11 @@ class BaseAssistantApplicationRunner(AppRunner): } ) - runtime_parameters = {} - - parameters = tool_entity.parameters or [] - user_parameters = tool_entity.get_runtime_parameters() or [] - - # override parameters - for parameter in user_parameters: - # check if parameter in tool parameters - found = False - for tool_parameter in parameters: - if tool_parameter.name == parameter.name: - found = True - break - - if found: - # override parameter - tool_parameter.type = parameter.type - tool_parameter.form = parameter.form - tool_parameter.required = parameter.required - tool_parameter.default = parameter.default - tool_parameter.options = parameter.options - tool_parameter.llm_description = parameter.llm_description - else: - # add new parameter - parameters.append(parameter) - + parameters = tool_entity.get_all_runtime_parameters() for parameter in parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + parameter_type = 'string' enum = [] if parameter.type == ToolParameter.ToolParameterType.STRING: @@ -213,59 +192,16 @@ class BaseAssistantApplicationRunner(AppRunner): else: raise ValueError(f"parameter type {parameter.type} is not supported") - if parameter.form == ToolParameter.ToolParameterForm.FORM: - # get tool parameter from form - tool_parameter_config = tool.tool_parameters.get(parameter.name) - if not tool_parameter_config: - # get default value - tool_parameter_config = parameter.default - if not tool_parameter_config and parameter.required: - raise ValueError(f"tool parameter {parameter.name} not found in tool config") - - if parameter.type == ToolParameter.ToolParameterType.SELECT: - # check if tool_parameter_config in options - options = list(map(lambda x: x.value, parameter.options)) - if tool_parameter_config not in options: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} not in options {options}") - - # convert tool parameter config to correct type - try: - if parameter.type == ToolParameter.ToolParameterType.NUMBER: - # check if tool parameter is integer - if isinstance(tool_parameter_config, int): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, float): - tool_parameter_config = tool_parameter_config - elif isinstance(tool_parameter_config, str): - if '.' 
in tool_parameter_config: - tool_parameter_config = float(tool_parameter_config) - else: - tool_parameter_config = int(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: - tool_parameter_config = bool(tool_parameter_config) - elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: - tool_parameter_config = str(tool_parameter_config) - elif parameter.type == ToolParameter.ToolParameterType: - tool_parameter_config = str(tool_parameter_config) - except Exception as e: - raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not correct type") - - # save tool parameter to tool entity memory - runtime_parameters[parameter.name] = tool_parameter_config - - elif parameter.form == ToolParameter.ToolParameterForm.LLM: - message_tool.parameters['properties'][parameter.name] = { - "type": parameter_type, - "description": parameter.llm_description or '', - } + message_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } - if len(enum) > 0: - message_tool.parameters['properties'][parameter.name]['enum'] = enum + if len(enum) > 0: + message_tool.parameters['properties'][parameter.name]['enum'] = enum - if parameter.required: - message_tool.parameters['required'].append(parameter.name) - - tool_entity.runtime.runtime_parameters.update(runtime_parameters) + if parameter.required: + message_tool.parameters['required'].append(parameter.name) return message_tool, tool_entity @@ -305,6 +241,9 @@ class BaseAssistantApplicationRunner(AppRunner): tool_runtime_parameters = tool.get_runtime_parameters() or [] for parameter in tool_runtime_parameters: + if parameter.form != ToolParameter.ToolParameterForm.LLM: + continue + parameter_type = 'string' enum = [] if parameter.type == ToolParameter.ToolParameterType.STRING: @@ -320,18 +259,17 @@ class BaseAssistantApplicationRunner(AppRunner): else: raise ValueError(f"parameter type {parameter.type} is not supported") - if parameter.form == ToolParameter.ToolParameterForm.LLM: - prompt_tool.parameters['properties'][parameter.name] = { - "type": parameter_type, - "description": parameter.llm_description or '', - } + prompt_tool.parameters['properties'][parameter.name] = { + "type": parameter_type, + "description": parameter.llm_description or '', + } - if len(enum) > 0: - prompt_tool.parameters['properties'][parameter.name]['enum'] = enum + if len(enum) > 0: + prompt_tool.parameters['properties'][parameter.name]['enum'] = enum - if parameter.required: - if parameter.name not in prompt_tool.parameters['required']: - prompt_tool.parameters['required'].append(parameter.name) + if parameter.required: + if parameter.name not in prompt_tool.parameters['required']: + prompt_tool.parameters['required'].append(parameter.name) return prompt_tool @@ -404,13 +342,16 @@ class BaseAssistantApplicationRunner(AppRunner): created_by=self.user_id, ) db.session.add(message_file) + db.session.commit() + db.session.refresh(message_file) + result.append(( message_file, message.save_as )) - - db.session.commit() + db.session.close() + return result def create_agent_thought(self, message_id: str, message: str, @@ -447,6 +388,8 @@ class BaseAssistantApplicationRunner(AppRunner): db.session.add(thought) db.session.commit() + db.session.refresh(thought) + db.session.close() self.agent_thought_count += 1 @@ -464,6 +407,10 @@ class BaseAssistantApplicationRunner(AppRunner): """ Save agent thought """ + 
agent_thought = db.session.query(MessageAgentThought).filter( + MessageAgentThought.id == agent_thought.id + ).first() + + if thought is not None: agent_thought.thought = thought @@ -514,6 +461,7 @@ agent_thought.tool_labels_str = json.dumps(labels) db.session.commit() + db.session.close() def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: """ @@ -586,9 +534,14 @@ """ convert tool variables to db variables """ + db_variables = db.session.query(ToolConversationVariables).filter( + ToolConversationVariables.conversation_id == self.message.conversation_id, + ).first() + db_variables.updated_at = datetime.utcnow() db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() + db.session.close() def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]: """ @@ -644,4 +597,6 @@ if message.answer: result.append(AssistantPromptMessage(content=message.answer)) + db.session.close() + return result \ No newline at end of file diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py new file mode 100644 index 0000000000..db05eb1875 --- /dev/null +++ b/api/core/helper/tool_parameter_cache.py @@ -0,0 +1,54 @@ +import json +from enum import Enum +from json import JSONDecodeError +from typing import Optional + +from extensions.ext_redis import redis_client + + +class ToolParameterCacheType(Enum): + PARAMETER = "tool_parameter" + +class ToolParameterCache: + def __init__(self, + tenant_id: str, + provider: str, + tool_name: str, + cache_type: ToolParameterCacheType + ): + self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" + + def get(self) -> Optional[dict]: + """ + Get cached tool parameters. + + :return: + """ + cached_tool_parameter = redis_client.get(self.cache_key) + if cached_tool_parameter: + try: + cached_tool_parameter = cached_tool_parameter.decode('utf-8') + cached_tool_parameter = json.loads(cached_tool_parameter) + except JSONDecodeError: + return None + + return cached_tool_parameter + else: + return None + + def set(self, parameters: dict) -> None: + """ + Cache tool parameters. + + :param parameters: tool parameters + :return: + """ + redis_client.setex(self.cache_key, 86400, json.dumps(parameters)) + + def delete(self) -> None: + """ + Delete cached tool parameters. 
+ + :return: + """ + redis_client.delete(self.cache_key) \ No newline at end of file diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 880a30cdf4..45ad1b51bf 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -82,6 +82,8 @@ class HostingConfiguration: RestrictModel(model="gpt-35-turbo-16k", base_model_name="gpt-35-turbo-16k", model_type=ModelType.LLM), RestrictModel(model="text-davinci-003", base_model_name="text-davinci-003", model_type=ModelType.LLM), RestrictModel(model="text-embedding-ada-002", base_model_name="text-embedding-ada-002", model_type=ModelType.TEXT_EMBEDDING), + RestrictModel(model="text-embedding-3-small", base_model_name="text-embedding-3-small", model_type=ModelType.TEXT_EMBEDDING), + RestrictModel(model="text-embedding-3-large", base_model_name="text-embedding-3-large", model_type=ModelType.TEXT_EMBEDDING), ] ) quotas.append(trial_quota) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index f1f8ab3a3b..4d44ac3818 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -47,11 +47,14 @@ class TokenBufferMemory: files, message.app_model_config ) - prompt_message_contents = [TextPromptMessageContent(data=message.query)] - for file_obj in file_objs: - prompt_message_contents.append(file_obj.prompt_message_content) + if not file_objs: + prompt_messages.append(UserPromptMessage(content=message.query)) + else: + prompt_message_contents = [TextPromptMessageContent(data=message.query)] + for file_obj in file_objs: + prompt_message_contents.append(file_obj.prompt_message_content) - prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) + prompt_messages.append(UserPromptMessage(content=prompt_message_contents)) else: prompt_messages.append(UserPromptMessage(content=message.query)) diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 52c2d66f9f..60cb655c98 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -17,7 +17,7 @@ class ModelType(Enum): SPEECH2TEXT = "speech2text" MODERATION = "moderation" TTS = "tts" - # TEXT2IMG = "text2img" + TEXT2IMG = "text2img" @classmethod def value_of(cls, origin_model_type: str) -> "ModelType": @@ -36,6 +36,8 @@ class ModelType(Enum): return cls.SPEECH2TEXT elif origin_model_type == 'tts' or origin_model_type == cls.TTS.value: return cls.TTS + elif origin_model_type == 'text2img' or origin_model_type == cls.TEXT2IMG.value: + return cls.TEXT2IMG elif origin_model_type == cls.MODERATION.value: return cls.MODERATION else: @@ -59,10 +61,11 @@ class ModelType(Enum): return 'tts' elif self == self.MODERATION: return 'moderation' + elif self == self.TEXT2IMG: + return 'text2img' else: raise ValueError(f'invalid model type {self}') - class FetchFrom(Enum): """ Enum class for fetch from. 
diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py new file mode 100644 index 0000000000..972a2ea14a --- /dev/null +++ b/api/core/model_runtime/model_providers/__base/text2img_model.py @@ -0,0 +1,48 @@ +from abc import abstractmethod +from typing import IO, Optional + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.ai_model import AIModel + + +class Text2ImageModel(AIModel): + """ + Model class for text2img model. + """ + model_type: ModelType = ModelType.TEXT2IMG + + def invoke(self, model: str, credentials: dict, prompt: str, + model_parameters: dict, user: Optional[str] = None) \ + -> list[IO[bytes]]: + """ + Invoke Text2Image model + + :param model: model name + :param credentials: model credentials + :param prompt: prompt for image generation + :param model_parameters: model parameters + :param user: unique user id + + :return: image bytes + """ + try: + return self._invoke(model, credentials, prompt, model_parameters, user) + except Exception as e: + raise self._transform_invoke_error(e) + + @abstractmethod + def _invoke(self, model: str, credentials: dict, prompt: str, + model_parameters: dict, user: Optional[str] = None) \ + -> list[IO[bytes]]: + """ + Invoke Text2Image model + + :param model: model name + :param credentials: model credentials + :param prompt: prompt for image generation + :param model_parameters: model parameters + :param user: unique user id + + :return: image bytes + """ + raise NotImplementedError diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 8c878d67d8..2dcdc1bf2e 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -7,6 +7,7 @@ - togetherai - ollama - mistralai +- groq - replicate - huggingface_hub - zhipuai diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 6f9f41ca44..ad74179353 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -424,8 +424,25 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): message_text = f"{human_prompt} {content}" + if not isinstance(message.content, list): + message_text = f"{ai_prompt} {content}" + else: + message_text = "" + for sub_message in message.content: + if sub_message.type == PromptMessageContentType.TEXT: + message_text += f"{human_prompt} {sub_message.data}" + elif sub_message.type == PromptMessageContentType.IMAGE: + message_text += f"{human_prompt} [IMAGE]" elif isinstance(message, AssistantPromptMessage): - message_text = f"{ai_prompt} {content}" + if not isinstance(message.content, list): + message_text = f"{ai_prompt} {content}" + else: + message_text = "" + for sub_message in message.content: + if sub_message.type == PromptMessageContentType.TEXT: + message_text += f"{ai_prompt} {sub_message.data}" + elif sub_message.type == PromptMessageContentType.IMAGE: + message_text += f"{ai_prompt} [IMAGE]" elif isinstance(message, SystemPromptMessage): message_text = content else: diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 90dd2e7a6b..7fc0da73fb 100644 --- 
a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -524,5 +524,62 @@ EMBEDDING_BASE_MODELS = [ currency='USD', ) ) + ), + AzureBaseModel( + base_model_name='text-embedding-3-small', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: 8191, + ModelPropertyKey.MAX_CHUNKS: 32, + }, + pricing=PriceConfig( + input=0.00002, + unit=0.001, + currency='USD', + ) + ) + ), + AzureBaseModel( + base_model_name='text-embedding-3-large', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: 8191, + ModelPropertyKey.MAX_CHUNKS: 32, + }, + pricing=PriceConfig( + input=0.00013, + unit=0.001, + currency='USD', + ) + ) + ) +] +SPEECH2TEXT_BASE_MODELS = [ + AzureBaseModel( + base_model_name='whisper-1', + entity=AIModelEntity( + model='fake-deployment-name', + label=I18nObject( + en_US='fake-deployment-name-label' + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.SPEECH2TEXT, + model_properties={ + ModelPropertyKey.FILE_UPLOAD_LIMIT: 25, + ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm' + } + ) ) ] diff --git a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml index c081808639..6c56ccc920 100644 --- a/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml +++ b/api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml @@ -15,6 +15,7 @@ help: supported_model_types: - llm - text-embedding + - speech2text configurate_methods: - customizable-model model_credential_schema: @@ -99,6 +100,24 @@ model_credential_schema: show_on: - variable: __model_type value: text-embedding + - label: + en_US: text-embedding-3-small + value: text-embedding-3-small + show_on: + - variable: __model_type + value: text-embedding + - label: + en_US: text-embedding-3-large + value: text-embedding-3-large + show_on: + - variable: __model_type + value: text-embedding + - label: + en_US: whisper-1 + value: whisper-1 + show_on: + - variable: __model_type + value: speech2text placeholder: zh_Hans: 在此输入您的模型版本 en_US: Enter your model version diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/__init__.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py new file mode 100644 index 0000000000..8aebcb90e4 --- /dev/null +++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py @@ -0,0 +1,82 @@ +import copy +from typing import IO, Optional + +from openai import AzureOpenAI + +from core.model_runtime.entities.model_entities import AIModelEntity +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI 
+from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel + + +class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel): + """ + Model class for OpenAI Speech to text model. + """ + + def _invoke(self, model: str, credentials: dict, + file: IO[bytes], user: Optional[str] = None) \ + -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :param user: unique user id + :return: text for given audio file + """ + return self._speech2text_invoke(model, credentials, file) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + audio_file_path = self._get_demo_file_path() + + with open(audio_file_path, 'rb') as audio_file: + self._speech2text_invoke(model, credentials, audio_file) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: + """ + Invoke speech2text model + + :param model: model name + :param credentials: model credentials + :param file: audio file + :return: text for given audio file + """ + # transform credentials to kwargs for model instance + credentials_kwargs = self._to_credential_kwargs(credentials) + + # init model client + client = AzureOpenAI(**credentials_kwargs) + + response = client.audio.transcriptions.create(model=model, file=file) + + return response.text + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model) + return ai_model_entity.entity + + + @staticmethod + def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel: + for ai_model_entity in SPEECH2TEXT_BASE_MODELS: + if ai_model_entity.base_model_name == base_model_name: + ai_model_entity_copy = copy.deepcopy(ai_model_entity) + ai_model_entity_copy.entity.model = model + ai_model_entity_copy.entity.label.en_US = model + ai_model_entity_copy.entity.label.zh_Hans = model + return ai_model_entity_copy + + return None diff --git a/api/core/model_runtime/model_providers/groq/_assets/icon_l_en.svg b/api/core/model_runtime/model_providers/groq/_assets/icon_l_en.svg new file mode 100644 index 0000000000..2505a5f493 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/_assets/icon_l_en.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/api/core/model_runtime/model_providers/groq/_assets/icon_s_en.svg b/api/core/model_runtime/model_providers/groq/_assets/icon_s_en.svg new file mode 100644 index 0000000000..087f37e471 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/_assets/icon_s_en.svg @@ -0,0 +1,4 @@ + + + + diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py new file mode 100644 index 0000000000..1421aaaf2b --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/groq.py @@ -0,0 +1,29 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + +class GroqProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) 
-> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + model_instance.validate_credentials( + model='llama2-70b-4096', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/groq/groq.yaml b/api/core/model_runtime/model_providers/groq/groq.yaml new file mode 100644 index 0000000000..db17cc8bdd --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/groq.yaml @@ -0,0 +1,32 @@ +provider: groq +label: + zh_Hans: GroqCloud + en_US: GroqCloud +description: + en_US: GroqCloud provides access to the Groq Cloud API, which hosts models like LLama2 and Mixtral. + zh_Hans: GroqCloud 提供对 Groq Cloud API 的访问,其中托管了 LLama2 和 Mixtral 等模型。 +icon_small: + en_US: icon_s_en.svg +icon_large: + en_US: icon_l_en.svg +background: "#F5F5F4" +help: + title: + en_US: Get your API Key from GroqCloud + zh_Hans: 从 GroqCloud 获取 API Key + url: + en_US: https://console.groq.com/ +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml b/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml new file mode 100644 index 0000000000..384912b0dd --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llama2-70b-4096.yaml @@ -0,0 +1,25 @@ +model: llama2-70b-4096 +label: + zh_Hans: Llama-2-70B-4096 + en_US: Llama-2-70B-4096 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 4096 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 4096 +pricing: + input: '0.7' + output: '0.8' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py new file mode 100644 index 0000000000..915f7a4e1a --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/llm.py @@ -0,0 +1,26 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class GroqLargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) + + def validate_credentials(self, model: str, credentials: dict) -> None: + 
self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials['mode'] = 'chat' + credentials['endpoint_url'] = 'https://api.groq.com/openai/v1' + diff --git a/api/core/model_runtime/model_providers/groq/llm/mixtral-8x7b-instruct-v0.1.yaml b/api/core/model_runtime/model_providers/groq/llm/mixtral-8x7b-instruct-v0.1.yaml new file mode 100644 index 0000000000..0dc6678fa2 --- /dev/null +++ b/api/core/model_runtime/model_providers/groq/llm/mixtral-8x7b-instruct-v0.1.yaml @@ -0,0 +1,25 @@ +model: mixtral-8x7b-32768 +label: + zh_Hans: Mixtral-8x7b-Instruct-v0.1 + en_US: Mixtral-8x7b-Instruct-v0.1 +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32768 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 20480 +pricing: + input: '0.27' + output: '0.27' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index c388341d51..50f8c73ed9 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,20 +1,32 @@ from os.path import abspath, dirname, join +from threading import Lock from transformers import AutoTokenizer class JinaTokenizer: - @staticmethod - def _get_num_tokens_by_jina_base(text: str) -> int: + _tokenizer = None + _lock = Lock() + + @classmethod + def _get_tokenizer(cls): + if cls._tokenizer is None: + with cls._lock: + if cls._tokenizer is None: + base_path = abspath(__file__) + gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) + return cls._tokenizer + + @classmethod + def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ use jina tokenizer to get num tokens """ - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') - tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) + tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - @staticmethod - def get_num_tokens(text: str) -> int: - return JinaTokenizer._get_num_tokens_by_jina_base(text) \ No newline at end of file + @classmethod + def get_num_tokens(cls, text: str) -> int: + return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml b/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml index 3cfe6f1a3a..6c14c76619 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml +++ b/api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml @@ -2,4 +2,4 @@ model: whisper-1 model_type: speech2text model_properties: file_upload_limit: 25 - supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm + supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text-embedidng-ada-002.yaml b/api/core/model_runtime/model_providers/openai/text_embedding/text-embedding-ada-002.yaml similarity index 100% rename from 
api/core/model_runtime/model_providers/openai/text_embedding/text-embedidng-ada-002.yaml rename to api/core/model_runtime/model_providers/openai/text_embedding/text-embedding-ada-002.yaml diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index ffb4a0328c..602d0b749f 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -308,6 +308,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): type=ParameterType.INT, use_template='max_tokens', min=1, + max=credentials.get('context_length', 2048), default=512, label=I18nObject( zh_Hans='最大生成长度', diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 1399e9ccd2..dd25037d34 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -44,6 +44,9 @@ class XinferenceRerankModel(RerankModel): docs=[] ) + if credentials['server_url'].endswith('/'): + credentials['server_url'] = credentials['server_url'][:-1] + # initialize client client = Client( base_url=credentials['server_url'] diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 8e2cd14be7..a41c727f35 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -10,7 +10,7 @@ from core.rag.models.document import Document class WordExtractor(BaseExtractor): - """Load pdf files. + """Load docx files. Args: @@ -46,14 +46,16 @@ class WordExtractor(BaseExtractor): def extract(self) -> list[Document]: """Load given path as single page.""" - import docx2txt + from docx import Document as docx_Document - return [ - Document( - page_content=docx2txt.process(self.file_path), - metadata={"source": self.file_path}, - ) - ] + document = docx_Document(self.file_path) + doc_texts = [paragraph.text for paragraph in document.paragraphs] + content = '\n'.join(doc_texts) + + return [Document( + page_content=content, + metadata={"source": self.file_path}, + )] @staticmethod def _is_valid_url(url: str) -> bool: diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index fcb06e5c84..509a1a189b 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -52,7 +52,7 @@ class BaseIndexProcessor(ABC): character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( chunk_size=segmentation["max_tokens"], - chunk_overlap=0, + chunk_overlap=segmentation.get('chunk_overlap', 0), fixed_separator=separator, separators=["\n\n", "。", ".", " ", ""], embedding_model_instance=embedding_model_instance @@ -61,7 +61,7 @@ class BaseIndexProcessor(ABC): # Automatic segmentation character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'], - chunk_overlap=0, + chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'], separators=["\n\n", "。", ".", " ", ""], embedding_model_instance=embedding_model_instance ) diff --git a/api/core/splitter/text_splitter.py b/api/core/splitter/text_splitter.py index e3d43c0658..5eeb237a96 100644 --- a/api/core/splitter/text_splitter.py +++ b/api/core/splitter/text_splitter.py @@ -30,7 +30,7 @@ def 
_split_text_with_regex( if separator: if keep_separator: # The parentheses in the pattern keep the delimiters in the result. - _splits = re.split(f"({separator})", text) + _splits = re.split(f"({re.escape(separator)})", text) splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] if len(_splits) % 2 == 0: splits += _splits[-1:] @@ -94,7 +94,7 @@ class TextSplitter(BaseDocumentTransformer, ABC): documents.append(new_doc) return documents - def split_documents(self, documents: Iterable[Document]) -> list[Document]: + def split_documents(self, documents: Iterable[Document] ) -> list[Document]: """Split documents.""" texts, metadatas = [], [] for doc in documents: diff --git a/api/core/tools/docs/en_US/tool_scale_out.md b/api/core/tools/docs/en_US/tool_scale_out.md index 589a3c8810..e0269e0209 100644 --- a/api/core/tools/docs/en_US/tool_scale_out.md +++ b/api/core/tools/docs/en_US/tool_scale_out.md @@ -119,7 +119,7 @@ parameters: # Parameter list - The `identity` field is mandatory, it contains the basic information of the tool, including name, author, label, description, etc. - `parameters` Parameter list - `name` Parameter name, unique, no duplication with other parameters - - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select` four types, corresponding to string, number, boolean, drop-down box + - `type` Parameter type, currently supports `string`, `number`, `boolean`, `select`, `secret-input` five types, corresponding to string, number, boolean, drop-down box, and encrypted input box, respectively. For sensitive information, we recommend using the `secret-input` type - `required` Required or not - In `llm` mode, if the parameter is required, the Agent is required to infer this parameter - In `form` mode, if the parameter is required, the user is required to fill in this parameter on the frontend before the conversation starts diff --git a/api/core/tools/docs/zh_Hans/tool_scale_out.md b/api/core/tools/docs/zh_Hans/tool_scale_out.md index be146a5aeb..20bb5e6dbc 100644 --- a/api/core/tools/docs/zh_Hans/tool_scale_out.md +++ b/api/core/tools/docs/zh_Hans/tool_scale_out.md @@ -119,7 +119,7 @@ parameters: # 参数列表 - `identity` 字段是必须的,它包含了工具的基本信息,包括名称、作者、标签、描述等 - `parameters` 参数列表 - `name` 参数名称,唯一,不允许和其他参数重名 - - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select` 四种类型,分别对应字符串、数字、布尔值、下拉框 + - `type` 参数类型,目前支持`string`、`number`、`boolean`、`select`、`secret-input` 五种类型,分别对应字符串、数字、布尔值、下拉框、加密输入框,对于敏感信息,我们建议使用`secret-input`类型 - `required` 是否必填 - 在`llm`模式下,如果参数为必填,则会要求Agent必须要推理出这个参数 - 在`form`模式下,如果参数为必填,则会要求用户在对话开始前在前端填写这个参数 diff --git a/api/core/tools/entities/common_entities.py b/api/core/tools/entities/common_entities.py index 13c27b57ee..55e31e8c35 100644 --- a/api/core/tools/entities/common_entities.py +++ b/api/core/tools/entities/common_entities.py @@ -8,15 +8,19 @@ class I18nObject(BaseModel): Model class for i18n object. 
""" zh_Hans: Optional[str] = None + pt_BR: Optional[str] = None en_US: str def __init__(self, **data): super().__init__(**data) if not self.zh_Hans: self.zh_Hans = self.en_US + if not self.pt_BR: + self.pt_BR = self.en_US def to_dict(self) -> dict: return { 'zh_Hans': self.zh_Hans, 'en_US': self.en_US, - } \ No newline at end of file + 'pt_BR': self.pt_BR + } diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index ad27706c3a..f7a61b0b0c 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -100,6 +100,7 @@ class ToolParameter(BaseModel): NUMBER = "number" BOOLEAN = "boolean" SELECT = "select" + SECRET_INPUT = "secret-input" class ToolParameterForm(Enum): SCHEMA = "schema" # should be set while adding tool @@ -304,4 +305,24 @@ class ToolRuntimeVariablePool(BaseModel): value=value, ) - self.pool.append(variable) \ No newline at end of file + self.pool.append(variable) + +class ModelToolPropertyKey(Enum): + IMAGE_PARAMETER_NAME = "image_parameter_name" + +class ModelToolConfiguration(BaseModel): + """ + Model tool configuration + """ + type: str = Field(..., description="The type of the model tool") + model: str = Field(..., description="The model") + label: I18nObject = Field(..., description="The label of the model tool") + properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + +class ModelToolProviderConfiguration(BaseModel): + """ + Model tool provider configuration + """ + provider: str = Field(..., description="The provider of the model tool") + models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") + label: I18nObject = Field(..., description="The label of the model tool") \ No newline at end of file diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/user_entities.py index 2641079333..8a5589da27 100644 --- a/api/core/tools/entities/user_entities.py +++ b/api/core/tools/entities/user_entities.py @@ -13,6 +13,7 @@ class UserToolProvider(BaseModel): BUILTIN = "builtin" APP = "app" API = "api" + MODEL = "model" id: str author: str diff --git a/api/core/tools/model_tools/anthropic.yaml b/api/core/tools/model_tools/anthropic.yaml new file mode 100644 index 0000000000..4ccb973df5 --- /dev/null +++ b/api/core/tools/model_tools/anthropic.yaml @@ -0,0 +1,20 @@ +provider: anthropic +label: + en_US: Anthropic Model Tools + zh_Hans: Anthropic 模型能力 + pt_BR: Anthropic Model Tools +models: + - type: llm + model: claude-3-sonnet-20240229 + label: + zh_Hans: Claude3 Sonnet 视觉 + en_US: Claude3 Sonnet Vision + properties: + image_parameter_name: image_id + - type: llm + model: claude-3-opus-20240229 + label: + zh_Hans: Claude3 Opus 视觉 + en_US: Claude3 Opus Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/model_tools/google.yaml b/api/core/tools/model_tools/google.yaml new file mode 100644 index 0000000000..d81e1b0735 --- /dev/null +++ b/api/core/tools/model_tools/google.yaml @@ -0,0 +1,13 @@ +provider: google +label: + en_US: Google Model Tools + zh_Hans: Google 模型能力 + pt_BR: Google Model Tools +models: + - type: llm + model: gemini-pro-vision + label: + zh_Hans: Gemini Pro 视觉 + en_US: Gemini Pro Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/model_tools/openai.yaml b/api/core/tools/model_tools/openai.yaml new file mode 100644 index 0000000000..45cbb295a9 --- /dev/null +++ b/api/core/tools/model_tools/openai.yaml @@ -0,0 
+1,13 @@ +provider: openai +label: + en_US: OpenAI Model Tools + zh_Hans: OpenAI 模型能力 + pt_BR: OpenAI Model Tools +models: + - type: llm + model: gpt-4-vision-preview + label: + zh_Hans: GPT-4 视觉 + en_US: GPT-4 Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/model_tools/zhipuai.yaml b/api/core/tools/model_tools/zhipuai.yaml new file mode 100644 index 0000000000..19a932eb89 --- /dev/null +++ b/api/core/tools/model_tools/zhipuai.yaml @@ -0,0 +1,13 @@ +provider: zhipuai +label: + en_US: ZhipuAI Model Tools + zh_Hans: ZhipuAI 模型能力 + pt_BR: ZhipuAI Model Tools +models: + - type: llm + model: glm-4v + label: + zh_Hans: GLM-4 视觉 + en_US: GLM-4 Vision + properties: + image_parameter_name: image_id diff --git a/api/core/tools/provider/_position.yaml b/api/core/tools/provider/_position.yaml index 2f332b1c31..ece9dbe159 100644 --- a/api/core/tools/provider/_position.yaml +++ b/api/core/tools/provider/_position.yaml @@ -1,14 +1,19 @@ - google - bing - duckduckgo -- yahoo -- wikipedia -- arxiv -- pubmed - dalle - azuredalle +- wikipedia +- model.openai +- model.google +- model.anthropic +- yahoo +- arxiv +- pubmed - stablediffusion - webscraper +- model.zhipuai +- aippt - youtube - wolframalpha - maths diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index c6cad0187e..fa2c5d27ef 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -4,24 +4,24 @@ from yaml import FullLoader, load from core.tools.entities.user_entities import UserToolProvider -position = {} class BuiltinToolProviderSort: - @staticmethod - def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]: - global position - if not position: + _position = {} + + @classmethod + def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: + if not cls._position: tmp_position = {} file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') with open(file_path) as f: for pos, val in enumerate(load(f, Loader=FullLoader)): tmp_position[val] = pos - position = tmp_position + cls._position = tmp_position def sort_compare(provider: UserToolProvider) -> int: - # if provider.type == UserToolProvider.ProviderType.MODEL: - # return position.get(f'model_provider.{provider.name}', 10000) - return position.get(provider.name, 10000) + if provider.type == UserToolProvider.ProviderType.MODEL: + return cls._position.get(f'model.{provider.name}', 10000) + return cls._position.get(provider.name, 10000) sorted_providers = sorted(providers, key=sort_compare) diff --git a/api/core/tools/provider/builtin/aippt/_assets/icon.png b/api/core/tools/provider/builtin/aippt/_assets/icon.png new file mode 100644 index 0000000000..b70618b487 Binary files /dev/null and b/api/core/tools/provider/builtin/aippt/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/aippt/aippt.py b/api/core/tools/provider/builtin/aippt/aippt.py new file mode 100644 index 0000000000..25133c51df --- /dev/null +++ b/api/core/tools/provider/builtin/aippt/aippt.py @@ -0,0 +1,11 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.aippt.tools.aippt import AIPPTGenerateTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class AIPPTProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + AIPPTGenerateTool._get_api_token(credentials, 
user_id='__dify_system__') + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/aippt/aippt.yaml b/api/core/tools/provider/builtin/aippt/aippt.yaml new file mode 100644 index 0000000000..b3ff1f6d98 --- /dev/null +++ b/api/core/tools/provider/builtin/aippt/aippt.yaml @@ -0,0 +1,42 @@ +identity: + author: Dify + name: aippt + label: + en_US: AIPPT + zh_Hans: AIPPT + description: + en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop + zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底 + icon: icon.png +credentials_for_provider: + aippt_access_key: + type: secret-input + required: true + label: + en_US: AIPPT API key + zh_Hans: AIPPT API key + pt_BR: AIPPT API key + help: + en_US: Please input your AIPPT API key + zh_Hans: 请输入你的 AIPPT API key + pt_BR: Please input your AIPPT API key + placeholder: + en_US: Please input your AIPPT API key + zh_Hans: 请输入你的 AIPPT API key + pt_BR: Please input your AIPPT API key + url: https://www.aippt.cn + aippt_secret_key: + type: secret-input + required: true + label: + en_US: AIPPT Secret key + zh_Hans: AIPPT Secret key + pt_BR: AIPPT Secret key + help: + en_US: Please input your AIPPT Secret key + zh_Hans: 请输入你的 AIPPT Secret key + pt_BR: Please input your AIPPT Secret key + placeholder: + en_US: Please input your AIPPT Secret key + zh_Hans: 请输入你的 AIPPT Secret key + pt_BR: Please input your AIPPT Secret key diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py new file mode 100644 index 0000000000..81465848a2 --- /dev/null +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -0,0 +1,541 @@ +from base64 import b64encode +from hashlib import sha1 +from hmac import new as hmac_new +from json import loads as json_loads +from threading import Lock +from time import sleep, time +from typing import Any + +from httpx import get, post +from requests import get as requests_get +from yarl import URL + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption +from core.tools.tool.builtin_tool import BuiltinTool + + +class AIPPTGenerateTool(BuiltinTool): + """ + A tool for generating a ppt + """ + + _api_base_url = URL('https://co.aippt.cn/api') + _api_token_cache = {} + _api_token_cache_lock = Lock() + _style_cache = {} + _style_cache_lock = Lock() + + _task = {} + _task_type_map = { + 'auto': 1, + 'markdown': 7, + } + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + Invokes the AIPPT generate tool with the given user ID and tool parameters. + + Args: + user_id (str): The ID of the user invoking the tool. + tool_parameters (dict[str, Any]): The parameters for the tool + + Returns: + ToolInvokeMessage | list[ToolInvokeMessage]: The result of the tool invocation, which can be a single message or a list of messages. 
+ """ + title = tool_parameters.get('title', '') + if not title: + return self.create_text_message('Please provide a title for the ppt') + + model = tool_parameters.get('model', 'aippt') + if not model: + return self.create_text_message('Please provide a model for the ppt') + + outline = tool_parameters.get('outline', '') + + # create task + task_id = self._create_task( + type=self._task_type_map['auto' if not outline else 'markdown'], + title=title, + content=outline, + user_id=user_id + ) + + # get suit + color = tool_parameters.get('color') + style = tool_parameters.get('style') + + if color == '__default__': + color_id = '' + else: + color_id = int(color.split('-')[1]) + + if style == '__default__': + style_id = '' + else: + style_id = int(style.split('-')[1]) + + suit_id = self._get_suit(style_id=style_id, colour_id=color_id) + + # generate outline + if not outline: + self._generate_outline( + task_id=task_id, + model=model, + user_id=user_id + ) + + # generate content + self._generate_content( + task_id=task_id, + model=model, + user_id=user_id + ) + + # generate ppt + _, ppt_url = self._generate_ppt( + task_id=task_id, + suit_id=suit_id, + user_id=user_id + ) + + return self.create_text_message('''the ppt has been created successfully,''' + f'''the ppt url is {ppt_url}''' + '''please give the ppt url to user and direct user to download it.''') + + def _create_task(self, type: int, title: str, content: str, user_id: str) -> str: + """ + Create a task + + :param type: the task type + :param title: the task title + :param content: the task content + + :return: the task ID + """ + headers = { + 'x-channel': '', + 'x-api-key': self.runtime.credentials['aippt_access_key'], + 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + } + response = post( + str(self._api_base_url / 'ai' / 'chat' / 'v2' / 'task'), + headers=headers, + files={ + 'type': ('', str(type)), + 'title': ('', title), + 'content': ('', content) + } + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + response = response.json() + if response.get('code') != 0: + raise Exception(f'Failed to create task: {response.get("msg")}') + + return response.get('data', {}).get('id') + + def _generate_outline(self, task_id: str, model: str, user_id: str) -> str: + api_url = self._api_base_url / 'ai' / 'chat' / 'outline' if model == 'aippt' else \ + self._api_base_url / 'ai' / 'chat' / 'wx' / 'outline' + api_url %= {'task_id': task_id} + + headers = { + 'x-channel': '', + 'x-api-key': self.runtime.credentials['aippt_access_key'], + 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + } + + response = requests_get( + url=api_url, + headers=headers, + stream=True, + timeout=(10, 60) + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + outline = '' + for chunk in response.iter_lines(delimiter=b'\n\n'): + if not chunk: + continue + + event = '' + lines = chunk.decode('utf-8').split('\n') + for line in lines: + if line.startswith('event:'): + event = line[6:] + elif line.startswith('data:'): + data = line[5:] + if event == 'message': + try: + data = json_loads(data) + outline += data.get('content', '') + except Exception as e: + pass + elif event == 'close': + break + elif event == 'error' or event == 'filter': + raise Exception(f'Failed to generate outline: {data}') + + return outline + + def _generate_content(self, task_id: str, model: str, user_id: str) -> 
str: + api_url = self._api_base_url / 'ai' / 'chat' / 'content' if model == 'aippt' else \ + self._api_base_url / 'ai' / 'chat' / 'wx' / 'content' + api_url %= {'task_id': task_id} + + headers = { + 'x-channel': '', + 'x-api-key': self.runtime.credentials['aippt_access_key'], + 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + } + + response = requests_get( + url=api_url, + headers=headers, + stream=True, + timeout=(10, 60) + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + if model == 'aippt': + content = '' + for chunk in response.iter_lines(delimiter=b'\n\n'): + if not chunk: + continue + + event = '' + lines = chunk.decode('utf-8').split('\n') + for line in lines: + if line.startswith('event:'): + event = line[6:] + elif line.startswith('data:'): + data = line[5:] + if event == 'message': + try: + data = json_loads(data) + content += data.get('content', '') + except Exception as e: + pass + elif event == 'close': + break + elif event == 'error' or event == 'filter': + raise Exception(f'Failed to generate content: {data}') + + return content + elif model == 'wenxin': + response = response.json() + if response.get('code') != 0: + raise Exception(f'Failed to generate content: {response.get("msg")}') + + return response.get('data', '') + + return '' + + def _generate_ppt(self, task_id: str, suit_id: int, user_id) -> tuple[str, str]: + """ + Generate a ppt + + :param task_id: the task ID + :param suit_id: the suit ID + :return: the cover url of the ppt and the ppt url + """ + headers = { + 'x-channel': '', + 'x-api-key': self.runtime.credentials['aippt_access_key'], + 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id=user_id), + } + + response = post( + str(self._api_base_url / 'design' / 'v2' / 'save'), + headers=headers, + data={ + 'task_id': task_id, + 'template_id': suit_id + } + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + response = response.json() + if response.get('code') != 0: + raise Exception(f'Failed to generate ppt: {response.get("msg")}') + + id = response.get('data', {}).get('id') + cover_url = response.get('data', {}).get('cover_url') + + response = post( + str(self._api_base_url / 'download' / 'export' / 'file'), + headers=headers, + data={ + 'id': id, + 'format': 'ppt', + 'files_to_zip': False, + 'edit': True + } + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + response = response.json() + if response.get('code') != 0: + raise Exception(f'Failed to generate ppt: {response.get("msg")}') + + export_code = response.get('data') + if not export_code: + raise Exception('Failed to generate ppt, the export code is empty') + + current_iteration = 0 + while current_iteration < 50: + # get ppt url + response = post( + str(self._api_base_url / 'download' / 'export' / 'file' / 'result'), + headers=headers, + data={ + 'task_key': export_code + } + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + response = response.json() + if response.get('code') != 0: + raise Exception(f'Failed to generate ppt: {response.get("msg")}') + + if response.get('msg') == '导出中': + current_iteration += 1 + sleep(2) + continue + + ppt_url = response.get('data', []) + if len(ppt_url) == 0: + raise Exception('Failed to generate ppt, the ppt url is empty') + + return cover_url, ppt_url[0] + + raise 
Exception('Failed to generate ppt, the export is timeout') + + @classmethod + def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: + """ + Get API token + + :param credentials: the credentials + :return: the API token + """ + access_key = credentials['aippt_access_key'] + secret_key = credentials['aippt_secret_key'] + + cache_key = f'{access_key}#@#{user_id}' + + with cls._api_token_cache_lock: + # clear expired tokens + now = time() + for key in list(cls._api_token_cache.keys()): + if cls._api_token_cache[key]['expire'] < now: + del cls._api_token_cache[key] + + if cache_key in cls._api_token_cache: + return cls._api_token_cache[cache_key]['token'] + + # get token + headers = { + 'x-api-key': access_key, + 'x-timestamp': str(int(now)), + 'x-signature': cls._calculate_sign(access_key, secret_key, int(now)) + } + + param = { + 'uid': user_id, + 'channel': '' + } + + response = get( + str(cls._api_base_url / 'grant' / 'token'), + params=param, + headers=headers + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + response = response.json() + if response.get('code') != 0: + raise Exception(f'Failed to connect to aippt: {response.get("msg")}') + + token = response.get('data', {}).get('token') + expire = response.get('data', {}).get('time_expire') + + with cls._api_token_cache_lock: + cls._api_token_cache[cache_key] = { + 'token': token, + 'expire': now + expire + } + + return token + + @classmethod + def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: + return b64encode( + hmac_new( + key=secret_key.encode('utf-8'), + msg=f'GET@/api/grant/token/@{timestamp}'.encode(), + digestmod=sha1 + ).digest() + ).decode('utf-8') + + @classmethod + def _get_styles(cls, credentials: dict[str, str], user_id: str) -> tuple[list[dict], list[dict]]: + """ + Get styles + """ + + # check cache + with cls._style_cache_lock: + # clear expired styles + now = time() + for key in list(cls._style_cache.keys()): + if cls._style_cache[key]['expire'] < now: + del cls._style_cache[key] + + key = f'{credentials["aippt_access_key"]}#@#{user_id}' + if key in cls._style_cache: + return cls._style_cache[key]['colors'], cls._style_cache[key]['styles'] + + headers = { + 'x-channel': '', + 'x-api-key': credentials['aippt_access_key'], + 'x-token': cls._get_api_token(credentials=credentials, user_id=user_id) + } + response = get( + str(cls._api_base_url / 'template_component' / 'suit' / 'select'), + headers=headers + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + response = response.json() + + if response.get('code') != 0: + raise Exception(f'Failed to connect to aippt: {response.get("msg")}') + + colors = [{ + 'id': f'id-{item.get("id")}', + 'name': item.get('name'), + 'en_name': item.get('en_name', item.get('name')), + } for item in response.get('data', {}).get('colour') or []] + styles = [{ + 'id': f'id-{item.get("id")}', + 'name': item.get('title'), + } for item in response.get('data', {}).get('suit_style') or []] + + with cls._style_cache_lock: + cls._style_cache[key] = { + 'colors': colors, + 'styles': styles, + 'expire': now + 60 * 60 + } + + return colors, styles + + def get_styles(self, user_id: str) -> tuple[list[dict], list[dict]]: + """ + Get styles + + :param credentials: the credentials + :return: Tuple[list[dict[id, color]], list[dict[id, style]] + """ + if not self.runtime.credentials.get('aippt_access_key') or not 
self.runtime.credentials.get('aippt_secret_key'): + raise Exception('Please provide aippt credentials') + + return self._get_styles(credentials=self.runtime.credentials, user_id=user_id) + + def _get_suit(self, style_id: int, colour_id: int) -> int: + """ + Get suit + """ + headers = { + 'x-channel': '', + 'x-api-key': self.runtime.credentials['aippt_access_key'], + 'x-token': self._get_api_token(credentials=self.runtime.credentials, user_id='__dify_system__') + } + response = get( + str(self._api_base_url / 'template_component' / 'suit' / 'search'), + headers=headers, + params={ + 'style_id': style_id, + 'colour_id': colour_id, + 'page': 1, + 'page_size': 1 + } + ) + + if response.status_code != 200: + raise Exception(f'Failed to connect to aippt: {response.text}') + + response = response.json() + + if response.get('code') != 0: + raise Exception(f'Failed to connect to aippt: {response.get("msg")}') + + if len(response.get('data', {}).get('list') or []) > 0: + return response.get('data', {}).get('list')[0].get('id') + + raise Exception('Failed to get suit, the suit does not exist, please check the style and color') + + def get_runtime_parameters(self) -> list[ToolParameter]: + """ + Get runtime parameters + + Override this method to add runtime parameters to the tool. + """ + try: + colors, styles = self.get_styles(user_id='__dify_system__') + except Exception as e: + colors, styles = [ + {'id': -1, 'name': '__default__', 'en_name': '__default__'} + ], [ + {'id': -1, 'name': '__default__', 'en_name': '__default__'} + ] + + return [ + ToolParameter( + name='color', + label=I18nObject(zh_Hans='颜色', en_US='Color'), + human_description=I18nObject(zh_Hans='颜色', en_US='Color'), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + required=False, + default=colors[0]['id'], + options=[ + ToolParameterOption( + value=color['id'], + label=I18nObject(zh_Hans=color['name'], en_US=color['en_name']) + ) for color in colors + ] + ), + ToolParameter( + name='style', + label=I18nObject(zh_Hans='风格', en_US='Style'), + human_description=I18nObject(zh_Hans='风格', en_US='Style'), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + required=False, + default=styles[0]['id'], + options=[ + ToolParameterOption( + value=style['id'], + label=I18nObject(zh_Hans=style['name'], en_US=style['name']) + ) for style in styles + ] + ), + ] \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.yaml b/api/core/tools/provider/builtin/aippt/tools/aippt.yaml new file mode 100644 index 0000000000..d35798ad66 --- /dev/null +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.yaml @@ -0,0 +1,54 @@ +identity: + name: aippt + author: Dify + label: + en_US: AIPPT + zh_Hans: AIPPT +description: + human: + en_US: AI-generated PPT with one click, input your content topic, and let AI serve you one-stop + zh_Hans: AI一键生成PPT,输入你的内容主题,让AI为你一站式服务到底 + llm: A tool used to generate PPT with AI, input your content topic, and let AI generate PPT for you. +parameters: + - name: title + type: string + required: true + label: + en_US: Title + zh_Hans: 标题 + human_description: + en_US: The title of the PPT. + zh_Hans: PPT的标题。 + llm_description: The title of the PPT, which will be used to generate the PPT outline. 
+ form: llm + - name: outline + type: string + required: false + label: + en_US: Outline + zh_Hans: 大纲 + human_description: + en_US: The outline of the PPT + zh_Hans: PPT的大纲 + llm_description: The outline of the PPT, which will be used to generate the PPT content. provide it if you have. + form: llm + - name: llm + type: select + required: true + label: + en_US: LLM model + zh_Hans: 生成大纲的LLM + options: + - value: aippt + label: + en_US: AIPPT default model + zh_Hans: AIPPT默认模型 + - value: wenxin + label: + en_US: Wenxin ErnieBot + zh_Hans: 文心一言 + default: aippt + human_description: + en_US: The LLM model used for generating PPT outline. + zh_Hans: 用于生成PPT大纲的LLM模型。 + form: form diff --git a/api/core/tools/provider/builtin/bing/_assets/icon.png b/api/core/tools/provider/builtin/bing/_assets/icon.png deleted file mode 100644 index 1a7b3225a9..0000000000 Binary files a/api/core/tools/provider/builtin/bing/_assets/icon.png and /dev/null differ diff --git a/api/core/tools/provider/builtin/bing/_assets/icon.svg b/api/core/tools/provider/builtin/bing/_assets/icon.svg new file mode 100644 index 0000000000..a94de7971d --- /dev/null +++ b/api/core/tools/provider/builtin/bing/_assets/icon.svg @@ -0,0 +1,40 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/bing/bing.yaml b/api/core/tools/provider/builtin/bing/bing.yaml index ee3aaf1968..9df836929c 100644 --- a/api/core/tools/provider/builtin/bing/bing.yaml +++ b/api/core/tools/provider/builtin/bing/bing.yaml @@ -9,7 +9,7 @@ identity: en_US: Bing Search zh_Hans: Bing 搜索 pt_BR: Bing Search - icon: icon.png + icon: icon.svg credentials_for_provider: subscription_key: type: secret-input diff --git a/api/core/tools/provider/builtin/gaode/_assets/icon.png b/api/core/tools/provider/builtin/gaode/_assets/icon.png deleted file mode 100644 index d4aec4cda8..0000000000 Binary files a/api/core/tools/provider/builtin/gaode/_assets/icon.png and /dev/null differ diff --git a/api/core/tools/provider/builtin/gaode/_assets/icon.svg b/api/core/tools/provider/builtin/gaode/_assets/icon.svg new file mode 100644 index 0000000000..0f5729e17a --- /dev/null +++ b/api/core/tools/provider/builtin/gaode/_assets/icon.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/gaode/gaode.yaml b/api/core/tools/provider/builtin/gaode/gaode.yaml index 158c8d975f..bca53b22e9 100644 --- a/api/core/tools/provider/builtin/gaode/gaode.yaml +++ b/api/core/tools/provider/builtin/gaode/gaode.yaml @@ -9,7 +9,7 @@ identity: en_US: Autonavi Open Platform service toolkit. zh_Hans: 高德开放平台服务工具包。 pt_BR: Kit de ferramentas de serviço Autonavi Open Platform. - icon: icon.png + icon: icon.svg credentials_for_provider: api_key: type: secret-input diff --git a/api/core/tools/provider/builtin/github/_assets/icon.png b/api/core/tools/provider/builtin/github/_assets/icon.png deleted file mode 100644 index c1615e8914..0000000000 Binary files a/api/core/tools/provider/builtin/github/_assets/icon.png and /dev/null differ diff --git a/api/core/tools/provider/builtin/github/_assets/icon.svg b/api/core/tools/provider/builtin/github/_assets/icon.svg new file mode 100644 index 0000000000..d56adb2c2f --- /dev/null +++ b/api/core/tools/provider/builtin/github/_assets/icon.svg @@ -0,0 +1,17 @@ + + + github [#142] + Created with Sketch. 
+ + + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/github/github.yaml b/api/core/tools/provider/builtin/github/github.yaml index 540eab147c..d529e639cc 100644 --- a/api/core/tools/provider/builtin/github/github.yaml +++ b/api/core/tools/provider/builtin/github/github.yaml @@ -9,7 +9,7 @@ identity: en_US: GitHub is an online software source code hosting service. zh_Hans: GitHub是一个在线软件源代码托管服务平台。 pt_BR: GitHub é uma plataforma online para serviços de hospedagem de código fonte de software. - icon: icon.png + icon: icon.svg credentials_for_provider: access_tokens: type: secret-input diff --git a/api/core/tools/provider/builtin/qrcode/_assets/icon.svg b/api/core/tools/provider/builtin/qrcode/_assets/icon.svg index d44bb0bca9..979bdda455 100644 --- a/api/core/tools/provider/builtin/qrcode/_assets/icon.svg +++ b/api/core/tools/provider/builtin/qrcode/_assets/icon.svg @@ -1,5 +1,4 @@ - diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py index a86f17a999..8db5b1f8e8 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.py @@ -2,14 +2,23 @@ import io import logging from typing import Any, Union -import qrcode +from qrcode.constants import ERROR_CORRECT_H, ERROR_CORRECT_L, ERROR_CORRECT_M, ERROR_CORRECT_Q +from qrcode.image.base import BaseImage from qrcode.image.pure import PyPNGImage +from qrcode.main import QRCode from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool class QRCodeGeneratorTool(BuiltinTool): + error_correction_levels = { + 'L': ERROR_CORRECT_L, # <=7% + 'M': ERROR_CORRECT_M, # <=15% + 'Q': ERROR_CORRECT_Q, # <=25% + 'H': ERROR_CORRECT_H, # <=30% + } + def _invoke(self, user_id: str, tool_parameters: dict[str, Any], @@ -17,19 +26,44 @@ class QRCodeGeneratorTool(BuiltinTool): """ invoke tools """ - # get expression + # get text content content = tool_parameters.get('content', '') if not content: return self.create_text_message('Invalid parameter content') + # get border size + border = tool_parameters.get('border', 0) + if border < 0 or border > 100: + return self.create_text_message('Invalid parameter border') + + # get error_correction + error_correction = tool_parameters.get('error_correction', '') + if error_correction not in self.error_correction_levels.keys(): + return self.create_text_message('Invalid parameter error_correction') + try: - img = qrcode.make(data=content, image_factory=PyPNGImage) - byte_stream = io.BytesIO() - img.save(byte_stream) - byte_array = byte_stream.getvalue() - return self.create_blob_message(blob=byte_array, + image = self._generate_qrcode(content, border, error_correction) + image_bytes = self._image_to_byte_array(image) + return self.create_blob_message(blob=image_bytes, meta={'mime_type': 'image/png'}, save_as=self.VARIABLE_KEY.IMAGE.value) except Exception: logging.exception(f'Failed to generate QR code for content: {content}') return self.create_text_message('Failed to generate QR code') + + def _generate_qrcode(self, content: str, border: int, error_correction: str) -> BaseImage: + qr = QRCode( + image_factory=PyPNGImage, + error_correction=self.error_correction_levels.get(error_correction), + border=border, + ) + qr.add_data(data=content) + qr.make(fit=True) + img = qr.make_image() + return img + + @staticmethod + def _image_to_byte_array(image: BaseImage) 
-> bytes: + byte_stream = io.BytesIO() + image.save(byte_stream) + return byte_stream.getvalue() diff --git a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml index ca562ac094..8c8b8c449a 100644 --- a/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml +++ b/api/core/tools/provider/builtin/qrcode/tools/qrcode_generator.yaml @@ -2,9 +2,9 @@ identity: name: qrcode_generator author: Bowen Liang label: - en_US: QR Code Generator - zh_Hans: 二维码生成器 - pt_BR: QR Code Generator + en_US: Generate QR Code + zh_Hans: 生成二维码 + pt_BR: Generate QR Code description: human: en_US: A tool for generating QR code image @@ -24,3 +24,53 @@ parameters: zh_Hans: 二维码文本内容 pt_BR: 二维码文本内容 form: llm + - name: error_correction + type: select + required: true + default: M + label: + en_US: Error Correction + zh_Hans: 容错等级 + pt_BR: Error Correction + human_description: + en_US: Error Correction in L, M, Q or H, from low to high, the bigger size of generated QR code with the better error correction effect + zh_Hans: 容错等级,可设置为低、中、偏高或高,从低到高,生成的二维码越大且容错效果越好 + pt_BR: Error Correction in L, M, Q or H, from low to high, the bigger size of generated QR code with the better error correction effect + options: + - value: L + label: + en_US: Low + zh_Hans: 低 + pt_BR: Low + - value: M + label: + en_US: Medium + zh_Hans: 中 + pt_BR: Medium + - value: Q + label: + en_US: Quartile + zh_Hans: 偏高 + pt_BR: Quartile + - value: H + label: + en_US: High + zh_Hans: 高 + pt_BR: High + form: form + - name: border + type: number + required: true + default: 2 + min: 0 + max: 100 + label: + en_US: border size + zh_Hans: 边框粗细 + pt_BR: border size + human_description: + en_US: border size(default to 2) + zh_Hans: 边框粗细的格数(默认为2) + pt_BR: border size(default to 2) + llm: border size, default to 2 + form: form diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index cda89036dc..e449062718 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -2,11 +2,11 @@ import io import json from base64 import b64decode, b64encode from copy import deepcopy -from os.path import join from typing import Any, Union from httpx import get, post from PIL import Image +from yarl import URL from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption @@ -79,7 +79,7 @@ class StableDiffusionTool(BuiltinTool): # set model try: - url = join(base_url, 'sdapi/v1/options') + url = str(URL(base_url) / 'sdapi' / 'v1' / 'options') response = post(url, data=json.dumps({ 'sd_model_checkpoint': model })) @@ -153,8 +153,21 @@ class StableDiffusionTool(BuiltinTool): if not model: raise ToolProviderCredentialValidationError('Please input model') - response = get(url=f'{base_url}/sdapi/v1/sd-models', timeout=120) - if response.status_code != 200: + api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + response = get(url=api_url, timeout=10) + if response.status_code == 404: + # try draw a picture + self._invoke( + user_id='test', + tool_parameters={ + 'prompt': 'a cat', + 'width': 1024, + 'height': 1024, + 'steps': 1, + 'lora': '', + } + ) + elif response.status_code != 200: raise ToolProviderCredentialValidationError('Failed to get models') else: models = 
[d['model_name'] for d in response.json()] @@ -165,6 +178,23 @@ class StableDiffusionTool(BuiltinTool): except Exception as e: raise ToolProviderCredentialValidationError(f'Failed to get models, {e}') + def get_sd_models(self) -> list[str]: + """ + get sd models + """ + try: + base_url = self.runtime.credentials.get('base_url', None) + if not base_url: + return [] + api_url = str(URL(base_url) / 'sdapi' / 'v1' / 'sd-models') + response = get(url=api_url, timeout=10) + if response.status_code != 200: + return [] + else: + return [d['model_name'] for d in response.json()] + except Exception as e: + return [] + def img2img(self, base_url: str, lora: str, image_binary: bytes, prompt: str, negative_prompt: str, width: int, height: int, steps: int) \ @@ -192,7 +222,7 @@ class StableDiffusionTool(BuiltinTool): draw_options['prompt'] = prompt try: - url = join(base_url, 'sdapi/v1/img2img') + url = str(URL(base_url) / 'sdapi' / 'v1' / 'img2img') response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: return self.create_text_message('Failed to generate image') @@ -225,7 +255,7 @@ class StableDiffusionTool(BuiltinTool): draw_options['negative_prompt'] = negative_prompt try: - url = join(base_url, 'sdapi/v1/txt2img') + url = str(URL(base_url) / 'sdapi' / 'v1' / 'txt2img') response = post(url, data=json.dumps(draw_options), timeout=120) if response.status_code != 200: return self.create_text_message('Failed to generate image') @@ -269,5 +299,29 @@ class StableDiffusionTool(BuiltinTool): label=I18nObject(en_US=i.name, zh_Hans=i.name) ) for i in self.list_default_image_variables()]) ) + + if self.runtime.credentials: + try: + models = self.get_sd_models() + if len(models) != 0: + parameters.append( + ToolParameter(name='model', + label=I18nObject(en_US='Model', zh_Hans='Model'), + human_description=I18nObject( + en_US='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', + zh_Hans='Stable Diffusion 的模型,您可以查看 Stable Diffusion 的官方文档', + ), + type=ToolParameter.ToolParameterType.SELECT, + form=ToolParameter.ToolParameterForm.FORM, + llm_description='Model of Stable Diffusion, you can check the official documentation of Stable Diffusion', + required=True, + default=models[0], + options=[ToolParameterOption( + value=i, + label=I18nObject(en_US=i, zh_Hans=i) + ) for i in models]) + ) + except: + pass return parameters diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index dbf30962f9..7984d7b3b1 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -1,5 +1,8 @@ from typing import Any +from twilio.base.exceptions import TwilioRestException +from twilio.rest import Client + from core.tools.errors import ToolProviderCredentialValidationError from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController @@ -7,19 +10,20 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl class TwilioProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict[str, Any]) -> None: try: - """ - SendMessageTool().fork_tool_runtime( - meta={ - "credentials": credentials, - } - ).invoke( - user_id="", - tool_parameters={ - "message": "Credential validation message", - "to_number": "+14846624384", - }, - ) - """ - pass + # Extract credentials + account_sid = credentials["account_sid"] + auth_token = credentials["auth_token"] + from_number = 
credentials["from_number"] + + # Initialize twilio client + client = Client(account_sid, auth_token) + + # fetch account + client.api.accounts(account_sid).fetch() + + except TwilioRestException as e: + raise ToolProviderCredentialValidationError(f"Twilio API error: {e.msg}") from e + except KeyError as e: + raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py index 5a536cca50..aca10e6a7f 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.py @@ -4,9 +4,10 @@ import httpx from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool +from core.tools.utils.uuid_utils import is_valid_uuid -class WecomRepositoriesTool(BuiltinTool): +class WecomGroupBotTool(BuiltinTool): def _invoke(self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ @@ -17,8 +18,9 @@ class WecomRepositoriesTool(BuiltinTool): return self.create_text_message('Invalid parameter content') hook_key = tool_parameters.get('hook_key', '') - if not hook_key: - return self.create_text_message('Invalid parameter hook_key') + if not is_valid_uuid(hook_key): + return self.create_text_message( + f'Invalid parameter hook_key ${hook_key}, not a valid UUID') msgtype = 'text' api_url = 'https://qyapi.weixin.qq.com/cgi-bin/webhook/send' diff --git a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml index 52f3cf5731..ece1bbc927 100644 --- a/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml +++ b/api/core/tools/provider/builtin/wecom/tools/wecom_group_bot.yaml @@ -14,7 +14,7 @@ description: llm: A tool for sending messages to a chat group on Wecom(企业微信) . 
parameters: - name: hook_key - type: string + type: secret-input required: true label: en_US: Wecom Group bot webhook key diff --git a/api/core/tools/provider/builtin/wecom/wecom.py b/api/core/tools/provider/builtin/wecom/wecom.py index 6380061b4f..7a2576b668 100644 --- a/api/core/tools/provider/builtin/wecom/wecom.py +++ b/api/core/tools/provider/builtin/wecom/wecom.py @@ -1,8 +1,8 @@ -from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomRepositoriesTool +from core.tools.provider.builtin.wecom.tools.wecom_group_bot import WecomGroupBotTool from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController class WecomProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: - WecomRepositoriesTool() + WecomGroupBotTool() pass diff --git a/api/core/tools/provider/builtin/youtube/_assets/icon.png b/api/core/tools/provider/builtin/youtube/_assets/icon.png deleted file mode 100644 index 3ab7908a5d..0000000000 Binary files a/api/core/tools/provider/builtin/youtube/_assets/icon.png and /dev/null differ diff --git a/api/core/tools/provider/builtin/youtube/_assets/icon.svg b/api/core/tools/provider/builtin/youtube/_assets/icon.svg new file mode 100644 index 0000000000..83b0700fec --- /dev/null +++ b/api/core/tools/provider/builtin/youtube/_assets/icon.svg @@ -0,0 +1,11 @@ + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.yaml b/api/core/tools/provider/builtin/youtube/youtube.yaml index 98b2905dfa..2f83ae43ee 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.yaml +++ b/api/core/tools/provider/builtin/youtube/youtube.yaml @@ -9,7 +9,7 @@ identity: en_US: YouTube zh_Hans: YouTube(油管)是全球最大的视频分享网站,用户可以在上面上传、观看和分享视频。 pt_BR: YouTube é o maior site de compartilhamento de vídeos do mundo, onde os usuários podem fazer upload, assistir e compartilhar vídeos. 
- icon: icon.png + icon: icon.svg credentials_for_provider: google_api_key: type: secret-input diff --git a/api/core/tools/provider/model_tool_provider.py b/api/core/tools/provider/model_tool_provider.py new file mode 100644 index 0000000000..ef47e9aae9 --- /dev/null +++ b/api/core/tools/provider/model_tool_provider.py @@ -0,0 +1,244 @@ +from copy import deepcopy +from typing import Any + +from core.entities.model_entities import ModelStatus +from core.errors.error import ProviderTokenNotInitError +from core.model_manager import ModelInstance +from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ModelToolPropertyKey, + ToolDescription, + ToolIdentity, + ToolParameter, + ToolProviderCredentials, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolNotFoundError +from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.model_tool import ModelTool +from core.tools.tool.tool import Tool +from core.tools.utils.configuration import ModelToolConfigurationManager + + +class ModelToolProviderController(ToolProviderController): + configuration: ProviderConfiguration = None + is_active: bool = False + + def __init__(self, configuration: ProviderConfiguration = None, **kwargs): + """ + init the provider + + :param data: the data of the provider + """ + super().__init__(**kwargs) + self.configuration = configuration + + @staticmethod + def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController': + """ + init the provider from db + + :param configuration: the configuration of the provider + """ + # check if all models are active + if configuration is None: + return None + is_active = True + models = configuration.get_provider_models() + for model in models: + if model.status != ModelStatus.ACTIVE: + is_active = False + break + + # get the provider configuration + model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider) + if model_tool_configuration is None: + raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}') + + # override the configuration + if model_tool_configuration.label: + label = deepcopy(model_tool_configuration.label) + if label.en_US: + label.en_US = model_tool_configuration.label.en_US + if label.zh_Hans: + label.zh_Hans = model_tool_configuration.label.zh_Hans + else: + label = I18nObject( + en_US=configuration.provider.label.en_US, + zh_Hans=configuration.provider.label.zh_Hans + ) + + return ModelToolProviderController( + is_active=is_active, + identity=ToolProviderIdentity( + author='Dify', + name=configuration.provider.provider, + description=I18nObject( + zh_Hans=f'{label.zh_Hans} 模型能力提供商', + en_US=f'{label.en_US} model capability provider' + ), + label=I18nObject( + zh_Hans=label.zh_Hans, + en_US=label.en_US + ), + icon=configuration.provider.icon_small.en_US, + ), + configuration=configuration, + credentials_schema={}, + ) + + @staticmethod + def is_configuration_valid(configuration: ProviderConfiguration) -> bool: + """ + check if the configuration has a model can be used as a tool + """ + models = configuration.get_provider_models() + for model in models: + if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): + return True + return False 
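Taken together, `is_configuration_valid` and `from_db` are the entry points that `ToolManager.list_model_providers` (later in this patch) uses to turn ordinary model providers into tool providers: only configurations exposing at least one vision-capable LLM qualify. Below is a minimal sketch of that composition, illustrative rather than part of the patch; the real method additionally skips providers that have no `model_tools/*.yaml` configuration.

```python
# Illustrative sketch mirroring ToolManager.list_model_providers further down
# in this patch; simplified to show how the two class methods compose.
from core.provider_manager import ProviderManager
from core.tools.provider.model_tool_provider import ModelToolProviderController


def list_vision_tool_providers(tenant_id: str) -> list[ModelToolProviderController]:
    providers: list[ModelToolProviderController] = []
    for configuration in ProviderManager().get_configurations(tenant_id).values():
        # only providers that expose a vision-capable LLM become tool providers
        if ModelToolProviderController.is_configuration_valid(configuration):
            providers.append(ModelToolProviderController.from_db(configuration))
    return providers
```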
+ + def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]: + """ + returns a list of tools that the provider can provide + + :return: list of tools + """ + tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' + provider_manager = ProviderManager() + if self.configuration is None: + configurations = provider_manager.get_configurations(tenant_id=tenant_id).values() + self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None) + # get all tools + tools: list[ModelTool] = [] + # get all models + if not self.configuration: + return tools + configuration = self.configuration + + provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider) + if provider_configuration is None: + raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}') + + for model in configuration.get_provider_models(): + model_configuration = ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model) + if model_configuration is None: + continue + + if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): + provider_instance = configuration.get_provider_instance() + model_type_instance = provider_instance.get_model_instance(model.model_type) + provider_model_bundle = ProviderModelBundle( + configuration=configuration, + provider_instance=provider_instance, + model_type_instance=model_type_instance + ) + + try: + model_instance = ModelInstance(provider_model_bundle, model.model) + except ProviderTokenNotInitError: + model_instance = None + + tools.append(ModelTool( + identity=ToolIdentity( + author='Dify', + name=model.model, + label=model_configuration.label, + ), + parameters=[ + ToolParameter( + name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value, + label=I18nObject(zh_Hans='图片ID', en_US='Image ID'), + human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + required=True, + default=Tool.VARIABLE_KEY.IMAGE.value + ) + ], + description=ToolDescription( + human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'), + llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.', + ), + is_team_authorization=model.status == ModelStatus.ACTIVE, + tool_type=ModelTool.ModelToolType.VISION, + model_instance=model_instance, + model=model.model, + )) + + self.tools = tools + return tools + + def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: + """ + returns the credentials schema of the provider + + :return: the credentials schema + """ + return {} + + def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]: + """ + returns a list of tools that the provider can provide + + :return: list of tools + """ + return self._get_model_tools(tenant_id=tenant_id) + + def get_tool(self, tool_name: str) -> ModelTool: + """ + get tool by name + + :param tool_name: the name of the tool + :return: the tool + """ + if self.tools is None: + self.get_tools(user_id='', tenant_id=self.configuration.tenant_id) + + for tool in self.tools: + if tool.identity.name == tool_name: + return tool + + raise ValueError(f'tool {tool_name} not found') + + def get_parameters(self, tool_name: str) -> list[ToolParameter]: + """ + returns the parameters of the tool + + :param tool_name: the name of the tool, defined in `get_tools` + :return: list of 
parameters + """ + tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) + if tool is None: + raise ToolNotFoundError(f'tool {tool_name} not found') + return tool.parameters + + @property + def app_type(self) -> ToolProviderType: + """ + returns the type of the provider + + :return: type of the provider + """ + return ToolProviderType.MODEL + + def validate_credentials(self, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool + """ + pass + + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + """ + validate the credentials of the provider + + :param tool_name: the name of the tool, defined in `get_tools` + :param credentials: the credentials of the tool + """ + pass \ No newline at end of file diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index 31519734ed..fa7e7567dd 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -12,6 +12,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.errors import ToolProviderCredentialValidationError from core.tools.tool.tool import Tool +API_TOOL_DEFAULT_TIMEOUT = (10, 60) class ApiTool(Tool): api_bundle: ApiBasedToolBundle @@ -211,19 +212,19 @@ class ApiTool(Tool): # do http request if method == 'get': - response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True) + response = ssrf_proxy.get(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) elif method == 'post': - response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True) + response = ssrf_proxy.post(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) elif method == 'put': - response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True) + response = ssrf_proxy.put(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) elif method == 'delete': - response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, allow_redirects=True) + response = ssrf_proxy.delete(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, allow_redirects=True) elif method == 'patch': - response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=10, follow_redirects=True) + response = ssrf_proxy.patch(url, params=params, headers=headers, cookies=cookies, data=body, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) elif method == 'head': - response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True) + response = ssrf_proxy.head(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) elif method == 'options': - response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=10, follow_redirects=True) + response = ssrf_proxy.options(url, params=params, headers=headers, cookies=cookies, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True) else: raise ValueError(f'Invalid http method {method}') 
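The `API_TOOL_DEFAULT_TIMEOUT = (10, 60)` constant introduced above replaces the bare `timeout=10` on every request. Read as a requests-style `(connect, read)` pair, it keeps connection failures fast while giving slow tool backends up to 60 seconds to respond; how the tuple is actually interpreted depends on the `ssrf_proxy` wrapper, which is outside this diff. If the client underneath is httpx, the explicit equivalent would look roughly like the sketch below (written under that assumption, not code from the patch):

```python
# Sketch assuming an httpx-based client behind ssrf_proxy (not shown in this diff).
import httpx

API_TOOL_DEFAULT_TIMEOUT = (10, 60)  # assumed meaning: (connect seconds, read seconds)

connect_s, read_s = API_TOOL_DEFAULT_TIMEOUT
# the positional value sets the read/write/pool timeouts; connect is overridden separately
timeout = httpx.Timeout(read_s, connect=connect_s)

# placeholder URL, purely for illustration
response = httpx.get('https://example.com/health', timeout=timeout, follow_redirects=True)
```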
diff --git a/api/core/tools/tool/model_tool.py b/api/core/tools/tool/model_tool.py new file mode 100644 index 0000000000..84e6610c75 --- /dev/null +++ b/api/core/tools/tool/model_tool.py @@ -0,0 +1,156 @@ +from base64 import b64encode +from enum import Enum +from typing import Any, cast + +from core.model_manager import ModelInstance +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import ( + PromptMessageContent, + PromptMessageContentType, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage +from core.tools.tool.tool import Tool + +VISION_PROMPT = """## Image Recognition Task +### Task Description +I require a powerful vision language model for an image recognition task. The model should be capable of extracting various details from the images, including but not limited to text content, layout distribution, color distribution, main subjects, and emotional expressions. +### Specific Requirements +1. **Text Content Extraction:** Ensure that the model accurately recognizes and extracts text content from the images, regardless of text size, font, or color. +2. **Layout Distribution Analysis:** The model should analyze the layout structure of the images, capturing the relationships between various elements and providing detailed information about the image layout. +3. **Color Distribution Analysis:** Extract information about color distribution in the images, including primary colors, color combinations, and other relevant details. +4. **Main Subject Recognition:** The model should accurately identify the main subjects in the images and provide detailed descriptions of these subjects. +5. **Emotional Expression Analysis:** Analyze and describe the emotions or expressions conveyed in the images based on facial expressions, postures, and other relevant features. +### Additional Considerations +- Ensure that the extracted information is as comprehensive and accurate as possible. +- For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results. 
+- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information.""" + +class ModelTool(Tool): + class ModelToolType(Enum): + """ + the type of the model tool + """ + VISION = 'vision' + + model_configuration: dict[str, Any] = None + tool_type: ModelToolType + + def __init__(self, model_instance: ModelInstance = None, model: str = None, + tool_type: ModelToolType = ModelToolType.VISION, + properties: dict[ModelToolPropertyKey, Any] = None, + **kwargs): + """ + init the tool + """ + kwargs['model_configuration'] = { + 'model_instance': model_instance, + 'model': model, + 'properties': properties + } + kwargs['tool_type'] = tool_type + super().__init__(**kwargs) + + """ + Model tool + """ + def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool': + """ + fork a new tool with meta data + + :param meta: the meta data of a tool call processing, tenant_id is required + :return: the new tool + """ + return self.__class__( + identity=self.identity.copy() if self.identity else None, + parameters=self.parameters.copy() if self.parameters else None, + description=self.description.copy() if self.description else None, + model_instance=self.model_configuration['model_instance'], + model=self.model_configuration['model'], + tool_type=self.tool_type, + runtime=Tool.Runtime(**meta) + ) + + def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> None: + """ + validate the credentials for Model tool + """ + pass + + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + """ + """ + model_instance = self.model_configuration['model_instance'] + if not model_instance: + return self.create_text_message('the tool is not configured correctly') + + if self.tool_type == ModelTool.ModelToolType.VISION: + return self._invoke_llm_vision(user_id, tool_parameters) + else: + return self.create_text_message('the tool is not configured correctly') + + def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: + # get image + image_parameter_name = self.model_configuration['properties'].get(ModelToolPropertyKey.IMAGE_PARAMETER_NAME, 'image_id') + image_id = tool_parameters.pop(image_parameter_name, '') + if not image_id: + image = self.get_default_image_variable() + if not image: + return self.create_text_message('Please upload an image or input image_id') + else: + image = self.get_variable(image_id) + if not image: + image = self.get_default_image_variable() + if not image: + return self.create_text_message('Please upload an image or input image_id') + + if not image: + return self.create_text_message('Please upload an image or input image_id') + + # get image + image = self.get_variable_file(image.name) + if not image: + return self.create_text_message('Failed to get image') + + # organize prompt messages + prompt_messages = [ + SystemPromptMessage( + content=VISION_PROMPT + ), + UserPromptMessage( + content=[ + PromptMessageContent( + type=PromptMessageContentType.TEXT, + data='Recognize the image and extract the information from the image.' 
+ ), + PromptMessageContent( + type=PromptMessageContentType.IMAGE, + data=f'data:image/png;base64,{b64encode(image).decode("utf-8")}' + ) + ] + ) + ] + + llm_instance = cast(LargeLanguageModel, self.model_configuration['model_instance']) + result: LLMResult = llm_instance.invoke( + model=self.model_configuration['model'], + credentials=self.runtime.credentials, + prompt_messages=prompt_messages, + model_parameters=tool_parameters, + tools=[], + stop=[], + stream=False, + user=user_id, + ) + + if not result: + return self.create_text_message('Failed to extract information from the image') + + # get result + content = result.message.content + if not content: + return self.create_text_message('Failed to extract information from the image') + + return self.create_text_message(content) \ No newline at end of file diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 192793897e..351ae4362e 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -266,6 +266,40 @@ class Tool(BaseModel, ABC): """ return self.parameters + def get_all_runtime_parameters(self) -> list[ToolParameter]: + """ + get all runtime parameters + + :return: all runtime parameters + """ + parameters = self.parameters or [] + parameters = parameters.copy() + user_parameters = self.get_runtime_parameters() or [] + user_parameters = user_parameters.copy() + + # override parameters + for parameter in user_parameters: + # check if parameter in tool parameters + found = False + for tool_parameter in parameters: + if tool_parameter.name == parameter.name: + found = True + break + + if found: + # override parameter + tool_parameter.type = parameter.type + tool_parameter.form = parameter.form + tool_parameter.required = parameter.required + tool_parameter.default = parameter.default + tool_parameter.options = parameter.options + tool_parameter.llm_description = parameter.llm_description + else: + # add new parameter + parameters.append(parameter) + + return parameters + def is_tool_available(self) -> bool: """ check if the tool is available diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index fd4748db70..2ac8f27bab 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -6,20 +6,33 @@ from os import listdir, path from typing import Any, Union from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler +from core.entities.application_entities import AgentToolEntity from core.model_runtime.entities.message_entities import PromptMessage +from core.provider_manager import ProviderManager from core.tools.entities.common_entities import I18nObject from core.tools.entities.constant import DEFAULT_PROVIDERS -from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeMessage, ToolProviderCredentials +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolInvokeMessage, + ToolParameter, + ToolProviderCredentials, +) from core.tools.entities.user_entities import UserToolProvider from core.tools.errors import ToolProviderNotFoundError from core.tools.provider.api_tool_provider import ApiBasedToolProviderController from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity from core.tools.provider.builtin._positions import BuiltinToolProviderSort from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController +from core.tools.provider.model_tool_provider import ModelToolProviderController from core.tools.provider.tool_provider import 
ToolProviderController from core.tools.tool.api_tool import ApiTool from core.tools.tool.builtin_tool import BuiltinTool -from core.tools.utils.configuration import ToolConfiguration +from core.tools.tool.tool import Tool +from core.tools.utils.configuration import ( + ModelToolConfigurationManager, + ToolConfigurationManager, + ToolParameterConfigurationManager, +) from core.tools.utils.encoder import serialize_base_model_dict from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider @@ -135,7 +148,7 @@ class ToolManager: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @staticmethod - def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id, + def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, agent_callback: DifyAgentCallbackHandler = None) \ -> Union[BuiltinTool, ApiTool]: """ @@ -170,7 +183,7 @@ class ToolManager: # decrypt the credentials credentials = builtin_provider.credentials controller = ToolManager.get_builtin_provider(provider_name) - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) @@ -187,18 +200,96 @@ class ToolManager: api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name) # decrypt the credentials - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) return api_provider.get_tool(tool_name).fork_tool_runtime(meta={ 'tenant_id': tenant_id, 'credentials': decrypted_credentials, }) + elif provider_type == 'model': + if tenant_id is None: + raise ValueError('tenant id is required for model provider') + # get model provider + model_provider = ToolManager.get_model_provider(tenant_id, provider_name) + + # get tool + model_tool = model_provider.get_tool(tool_name) + + return model_tool.fork_tool_runtime(meta={ + 'tenant_id': tenant_id, + 'credentials': model_tool.model_configuration['model_instance'].credentials + }) elif provider_type == 'app': raise NotImplementedError('app provider not implemented') else: raise ToolProviderNotFoundError(f'provider type {provider_type} not found') + @staticmethod + def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool: + """ + get the agent tool runtime + """ + tool_entity = ToolManager.get_tool_runtime( + provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, + tenant_id=tenant_id, + agent_callback=agent_callback + ) + runtime_parameters = {} + parameters = tool_entity.get_all_runtime_parameters() + for parameter in parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM: + # get tool parameter from form + tool_parameter_config = agent_tool.tool_parameters.get(parameter.name) + if not tool_parameter_config: + # get default value + tool_parameter_config = parameter.default + if not tool_parameter_config and parameter.required: + raise ValueError(f"tool parameter {parameter.name} not found in tool config") + + if parameter.type == ToolParameter.ToolParameterType.SELECT: + # check if 
tool_parameter_config in options + options = list(map(lambda x: x.value, parameter.options)) + if tool_parameter_config not in options: + raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not in options {options}") + + # convert the tool parameter config to the correct type + try: + if parameter.type == ToolParameter.ToolParameterType.NUMBER: + # int and float values need no conversion; convert strings to int or float + if isinstance(tool_parameter_config, str): + if '.' in tool_parameter_config: + tool_parameter_config = float(tool_parameter_config) + else: + tool_parameter_config = int(tool_parameter_config) + elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN: + tool_parameter_config = bool(tool_parameter_config) + elif parameter.type not in [ToolParameter.ToolParameterType.SELECT, ToolParameter.ToolParameterType.STRING]: + tool_parameter_config = str(tool_parameter_config) + except Exception: + raise ValueError(f"tool parameter {parameter.name} value {tool_parameter_config} is not the correct type") + + # save tool parameter to tool entity memory + runtime_parameters[parameter.name] = tool_parameter_config + + # decrypt runtime parameters + encryption_manager = ToolParameterConfigurationManager( + tenant_id=tenant_id, + tool_runtime=tool_entity, + provider_name=agent_tool.provider_id, + provider_type=agent_tool.provider_type, + ) + runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters) + + tool_entity.runtime.runtime_parameters.update(runtime_parameters) + return tool_entity + @staticmethod + def get_builtin_provider_icon(provider: str) -> tuple[str, str]: """ @@ -266,6 +357,49 @@ return builtin_providers + @staticmethod + def list_model_providers(tenant_id: str = None) -> list[ModelToolProviderController]: + """ + list all the model providers + + :return: the list of the model providers + """ + tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' + # get configurations + model_configurations = ModelToolConfigurationManager.get_all_configuration() + # get all providers + provider_manager = ProviderManager() + configurations = provider_manager.get_configurations(tenant_id).values() + # get model providers + model_providers: list[ModelToolProviderController] = [] + for configuration in configurations: + # every model tool provider must have a configuration entry + if configuration.provider.provider not in model_configurations: + continue + if not ModelToolProviderController.is_configuration_valid(configuration): + continue + model_providers.append(ModelToolProviderController.from_db(configuration)) + + return model_providers + + @staticmethod + def get_model_provider(tenant_id: str, provider_name: str) -> ModelToolProviderController: + """ + get the model provider + + :param provider_name: the name of the provider + + :return: the provider + """ + # get configurations + provider_manager = ProviderManager() + configurations = provider_manager.get_configurations(tenant_id) + configuration = configurations.get(provider_name) + if configuration is None: + raise ToolProviderNotFoundError(f'model provider {provider_name} not found') + + return ModelToolProviderController.from_db(configuration) + @staticmethod + def get_tool_label(tool_name: str) -> 
Union[I18nObject, None]: """ @@ -338,13 +472,35 @@ class ToolManager: controller = ToolManager.get_builtin_provider(provider_name) # init tool configuration - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials) result_providers[provider_name].team_credentials = masked_credentials + # get model tool providers + model_providers = ToolManager.list_model_providers(tenant_id=tenant_id) + # append model providers + for provider in model_providers: + result_providers[f'model_provider.{provider.identity.name}'] = UserToolProvider( + id=provider.identity.name, + author=provider.identity.author, + name=provider.identity.name, + description=I18nObject( + en_US=provider.identity.description.en_US, + zh_Hans=provider.identity.description.zh_Hans, + ), + icon=provider.identity.icon, + label=I18nObject( + en_US=provider.identity.label.en_US, + zh_Hans=provider.identity.label.zh_Hans, + ), + type=UserToolProvider.ProviderType.MODEL, + team_credentials={}, + is_team_authorization=provider.is_active, + ) + # get db api providers db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \ filter(ApiToolProvider.tenant_id == tenant_id).all() @@ -383,7 +539,7 @@ class ToolManager: ) # init tool configuration - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) # decrypt the credentials and mask the credentials decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials) @@ -443,7 +599,7 @@ class ToolManager: provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE ) # init tool configuration - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 56a442a223..927af1f5be 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,14 +1,23 @@ -from typing import Any +import os +from typing import Any, Union from pydantic import BaseModel +from yaml import FullLoader, load from core.helper import encrypter +from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType -from core.tools.entities.tool_entities import ToolProviderCredentials +from core.tools.entities.tool_entities import ( + ModelToolConfiguration, + ModelToolProviderConfiguration, + ToolParameter, + ToolProviderCredentials, +) from core.tools.provider.tool_provider import ToolProviderController +from core.tools.tool.tool import Tool -class ToolConfiguration(BaseModel): +class ToolConfigurationManager(BaseModel): tenant_id: str provider_controller: ToolProviderController 
@@ -94,3 +103,187 @@ class ToolConfiguration(BaseModel): cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cache.delete() + +class ToolParameterConfigurationManager(BaseModel): + """ + Tool parameter configuration manager + """ + tenant_id: str + tool_runtime: Tool + provider_name: str + provider_type: str + + def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + deep copy parameters + """ + return {key: value for key, value in parameters.items()} + + def _merge_parameters(self) -> list[ToolParameter]: + """ + merge parameters + """ + # get tool parameters + tool_parameters = self.tool_runtime.parameters or [] + # get tool runtime parameters + runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] + # override parameters + current_parameters = tool_parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + return current_parameters + + def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + mask tool parameters + + return a deep copy of parameters with masked values + """ + parameters = self._deep_copy(parameters) + + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if parameter.name in parameters: + if len(parameters[parameter.name]) > 6: + parameters[parameter.name] = \ + parameters[parameter.name][:2] + \ + '*' * (len(parameters[parameter.name]) - 4) +\ + parameters[parameter.name][-2:] + else: + parameters[parameter.name] = '*' * len(parameters[parameter.name]) + + return parameters + + def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + encrypt tool parameters with tenant id + + return a deep copy of parameters with encrypted values + """ + # override parameters + current_parameters = self._merge_parameters() + + for parameter in current_parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if parameter.name in parameters: + encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name]) + parameters[parameter.name] = encrypted + + return parameters + + def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: + """ + decrypt tool parameters with tenant id + + return a deep copy of parameters with decrypted values + """ + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f'{self.provider_type}.{self.provider_name}', + tool_name=self.tool_runtime.identity.name, + cache_type=ToolParameterCacheType.PARAMETER + ) + cached_parameters = cache.get() + if cached_parameters: + return cached_parameters + + # override parameters + current_parameters = self._merge_parameters() + has_secret_input = False + + for parameter in current_parameters: + if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT: + if parameter.name in parameters: + try: + has_secret_input = True + parameters[parameter.name] = 
encrypter.decrypt_token(self.tenant_id, parameters[parameter.name]) + except: + pass + + if has_secret_input: + cache.set(parameters) + + return parameters + + def delete_tool_parameters_cache(self): + cache = ToolParameterCache( + tenant_id=self.tenant_id, + provider=f'{self.provider_type}.{self.provider_name}', + tool_name=self.tool_runtime.identity.name, + cache_type=ToolParameterCacheType.PARAMETER + ) + cache.delete() + +class ModelToolConfigurationManager: + """ + Model as tool configuration + """ + _configurations: dict[str, ModelToolProviderConfiguration] = {} + _model_configurations: dict[str, ModelToolConfiguration] = {} + _inited = False + + @classmethod + def _init_configuration(cls): + """ + init configuration + """ + + absolute_path = os.path.abspath(os.path.dirname(__file__)) + model_tools_path = os.path.join(absolute_path, '..', 'model_tools') + + # get all .yaml file + files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')] + + for file in files: + provider = file.split('.')[0] + with open(os.path.join(model_tools_path, file), encoding='utf-8') as f: + configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader)) + models = configurations.models or [] + for model in models: + model_key = f'{provider}.{model.model}' + cls._model_configurations[model_key] = model + + cls._configurations[provider] = configurations + cls._inited = True + + @classmethod + def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]: + """ + get configuration by provider + """ + if not cls._inited: + cls._init_configuration() + return cls._configurations.get(provider, None) + + @classmethod + def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]: + """ + get all configurations + """ + if not cls._inited: + cls._init_configuration() + return cls._configurations + + @classmethod + def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]: + """ + get model configuration + """ + key = f'{provider}.{model}' + + if not cls._inited: + cls._init_configuration() + + return cls._model_configurations.get(key, None) \ No newline at end of file diff --git a/api/core/tools/utils/uuid_utils.py b/api/core/tools/utils/uuid_utils.py new file mode 100644 index 0000000000..3046c08c89 --- /dev/null +++ b/api/core/tools/utils/uuid_utils.py @@ -0,0 +1,9 @@ +import uuid + + +def is_valid_uuid(uuid_str: str) -> bool: + try: + uuid.UUID(uuid_str) + return True + except Exception: + return False diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 87b15697ba..d2c6e32dfd 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -32,8 +32,6 @@ class Mail: from libs.smtp import SMTPClient if not app.config.get('SMTP_SERVER') or not app.config.get('SMTP_PORT'): raise ValueError('SMTP_SERVER and SMTP_PORT are required for smtp mail type') - if not app.config.get('SMTP_USERNAME') or not app.config.get('SMTP_PASSWORD'): - raise ValueError('SMTP_USERNAME and SMTP_PASSWORD are required for smtp mail type') self._client = SMTPClient( server=app.config.get('SMTP_SERVER'), port=app.config.get('SMTP_PORT'), diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 17fcc6de7c..6c8e0c2777 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -16,7 +16,8 @@ class SMTPClient: smtp = smtplib.SMTP(self.server, self.port) if self._use_tls: smtp.starttls() - smtp.login(self.username, self.password) + if (self.username): + smtp.login(self.username, self.password) msg = 
MIMEMultipart() msg['Subject'] = mail['subject'] msg['From'] = self._from diff --git a/api/requirements.txt b/api/requirements.txt index 9721c3a13d..847903c4f4 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -32,7 +32,7 @@ celery==5.2.7 redis~=4.5.4 openpyxl==3.1.2 chardet~=5.1.0 -docx2txt==0.8 +python-docx~=1.1.0 pypdfium2==4.16.0 resend~=0.7.0 pyjwt~=2.8.0 diff --git a/api/services/account_service.py b/api/services/account_service.py index e35d325ae4..103af7f79c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -15,7 +15,7 @@ from events.tenant_event import tenant_was_created from extensions.ext_redis import redis_client from libs.helper import get_remote_ip from libs.passport import PassportService -from libs.password import compare_password, hash_password +from libs.password import compare_password, hash_password, valid_password from libs.rsa import generate_key_pair from models.account import * from services.errors.account import ( @@ -58,7 +58,7 @@ class AccountService: account.current_tenant_id = available_ta.tenant_id available_ta.current = True db.session.commit() - + if datetime.utcnow() - account.last_active_at > timedelta(minutes=10): account.last_active_at = datetime.utcnow() db.session.commit() @@ -104,6 +104,9 @@ class AccountService: if account.password and not compare_password(password, account.password, account.password_salt): raise CurrentPasswordIncorrectError("Current password is incorrect.") + # may be raised + valid_password(new_password) + # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() @@ -140,9 +143,9 @@ class AccountService: account.interface_language = interface_language account.interface_theme = interface_theme - + # Set timezone based on language - account.timezone = language_timezone_mapping.get(interface_language, 'UTC') + account.timezone = language_timezone_mapping.get(interface_language, 'UTC') db.session.add(account) db.session.commit() @@ -279,7 +282,7 @@ class TenantService: tenant_account_join = TenantAccountJoin.query.filter_by(account_id=account.id, tenant_id=tenant_id).first() if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") - else: + else: TenantAccountJoin.query.filter(TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id).update({'current': False}) tenant_account_join.current = True # Set the current tenant for the account @@ -449,7 +452,7 @@ class RegisterService: return account @classmethod - def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: + def invite_new_member(cls, tenant: Tenant, email: str, language: str, role: str = 'normal', inviter: Account = None) -> str: """Invite new member""" account = Account.query.filter_by(email=email).first() diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index 0e3d481640..ff618e5d2b 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -9,6 +9,7 @@ from core.tools.entities.tool_entities import ( ApiProviderAuthType, ApiProviderSchemaType, ToolCredentialsOption, + ToolParameter, ToolProviderCredentials, ) from core.tools.entities.user_entities import UserTool, UserToolProvider @@ -16,11 +17,12 @@ from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidatio from core.tools.provider.api_tool_provider import 
ApiBasedToolProviderController from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ToolConfiguration +from core.tools.utils.configuration import ToolConfigurationManager from core.tools.utils.encoder import serialize_base_model_array, serialize_base_model_dict from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider +from services.model_provider_service import ModelProviderService class ToolManageService: @@ -49,11 +51,13 @@ class ToolManageService: :param provider: the provider dict """ url_prefix = (current_app.config.get("CONSOLE_API_URL") - + "/console/api/workspaces/current/tool-provider/builtin/") + + "/console/api/workspaces/current/tool-provider/") if 'icon' in provider: if provider['type'] == UserToolProvider.ProviderType.BUILTIN.value: - provider['icon'] = url_prefix + provider['name'] + '/icon' + provider['icon'] = url_prefix + 'builtin/' + provider['name'] + '/icon' + elif provider['type'] == UserToolProvider.ProviderType.MODEL.value: + provider['icon'] = url_prefix + 'model/' + provider['name'] + '/icon' elif provider['type'] == UserToolProvider.ProviderType.API.value: try: provider['icon'] = json.loads(provider['icon']) @@ -73,15 +77,52 @@ class ToolManageService: provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider) tools = provider_controller.get_tools() - result = [ - UserTool( + tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) + # check if user has added the provider + builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == provider, + ).first() + + credentials = {} + if builtin_provider is not None: + # get credentials + credentials = builtin_provider.credentials + credentials = tool_provider_configurations.decrypt_tool_credentials(credentials) + + result = [] + for tool in tools: + # fork tool runtime + tool = tool.fork_tool_runtime(meta={ + 'credentials': credentials, + 'tenant_id': tenant_id, + }) + + # get tool parameters + parameters = tool.parameters or [] + # get tool runtime parameters + runtime_parameters = tool.get_runtime_parameters() + # override parameters + current_parameters = parameters.copy() + for runtime_parameter in runtime_parameters: + found = False + for index, parameter in enumerate(current_parameters): + if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: + current_parameters[index] = runtime_parameter + found = True + break + + if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + current_parameters.append(runtime_parameter) + + user_tool = UserTool( author=tool.identity.author, name=tool.identity.name, label=tool.identity.label, description=tool.description.human, - parameters=tool.parameters or [] - ) for tool in tools - ] + parameters=current_parameters + ) + result.append(user_tool) return json.loads( serialize_base_model_array(result) @@ -238,7 +279,7 @@ class ToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, 
provider_controller=provider_controller) encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) db_provider.credentials_str = json.dumps(encrypted_credentials) @@ -325,7 +366,7 @@ class ToolManageService: provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: raise ValueError(f'provider {provider_name} does not need credentials') - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) # get original credentials if exists if provider is not None: original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) @@ -409,7 +450,7 @@ class ToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) @@ -449,7 +490,7 @@ class ToolManageService: # delete cache provider_controller = ToolManager.get_builtin_provider(provider_name) - tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=provider_controller) + tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() return { 'result': 'success' } @@ -467,6 +508,46 @@ class ToolManageService: return icon_bytes, mime_type + @staticmethod + def get_model_tool_provider_icon( + provider: str + ): + """ + get tool provider icon and it's mimetype + """ + + service = ModelProviderService() + icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US') + + if icon_bytes is None: + raise ValueError(f'provider {provider} does not exists') + + return icon_bytes, mime_type + + @staticmethod + def list_model_tool_provider_tools( + user_id: str, tenant_id: str, provider: str + ): + """ + list model tool provider tools + """ + provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider) + tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) + + result = [ + UserTool( + author=tool.identity.author, + name=tool.identity.name, + label=tool.identity.label, + description=tool.description.human, + parameters=tool.parameters or [] + ) for tool in tools + ] + + return json.loads( + serialize_base_model_array(result) + ) + @staticmethod def delete_api_tool_provider( user_id: str, tenant_id: str, provider_name: str @@ -551,7 +632,7 @@ class ToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ToolConfiguration( + tool_configuration = ToolConfigurationManager( tenant_id=tenant_id, provider_controller=provider_controller ) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx index cd55faf8e1..93141dd86e 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/chartView.tsx @@ -29,7 +29,17 @@ export default function 
ChartView({ appId }: IChartViewProps) { const [period, setPeriod] = useState({ name: t('appLog.filter.period.last7days'), query: { start: today.subtract(7, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) const onSelect = (item: Item) => { - setPeriod({ name: item.name, query: item.value === 'all' ? undefined : { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + if (item.value === 'all') { + setPeriod({ name: item.name, query: undefined }) + } + else if (item.value === 0) { + const startOfToday = today.startOf('day').format(queryDateFormat) + const endOfToday = today.endOf('day').format(queryDateFormat) + setPeriod({ name: item.name, query: { start: startOfToday, end: endOfToday } }) + } + else { + setPeriod({ name: item.name, query: { start: today.subtract(item.value as number, 'day').format(queryDateFormat), end: today.format(queryDateFormat) } }) + } } if (!response) diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 3cf88ce281..be3706037b 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -62,8 +62,10 @@ const ActivateForm = () => { showErrorMessage(t('login.error.passwordEmpty')) return false } - if (!validPassword.test(password)) + if (!validPassword.test(password)) { showErrorMessage(t('login.error.passwordInvalid')) + return false + } return true }, [name, password, showErrorMessage, t]) diff --git a/web/app/components/app/configuration/base/warning-mask/index.tsx b/web/app/components/app/configuration/base/warning-mask/index.tsx index 550c3f73a8..03df4f16df 100644 --- a/web/app/components/app/configuration/base/warning-mask/index.tsx +++ b/web/app/components/app/configuration/base/warning-mask/index.tsx @@ -24,7 +24,7 @@ const WarningMask: FC = ({ return (
-
+
{warningIcon}
{title} diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 424b77f0cd..caf2392d36 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -25,6 +25,7 @@ import { useToastContext } from '@/app/components/base/toast' import { useEventEmitterContextContext } from '@/context/event-emitter' import { ADD_EXTERNAL_DATA_TOOL } from '@/app/components/app/configuration/config-var' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' +import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' export type ISimplePromptInput = { mode: AppType @@ -122,6 +123,10 @@ const Prompt: FC = ({ if (mode === AppType.chat) setIntroduction(res.opening_statement) showAutomaticFalse() + eventEmitter?.emit({ + type: PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, + payload: res.prompt, + } as any) } const minHeight = 228 const [editorHeight, setEditorHeight] = useState(minHeight) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index cd13636c94..95858d9540 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -34,7 +34,7 @@ const AgentTools: FC = () => { const [selectedProviderId, setSelectedProviderId] = useState(undefined) const [isShowSettingTool, setIsShowSettingTool] = useState(false) const tools = (modelConfig?.agentConfig?.tools as AgentTool[] || []).map((item) => { - const collection = collectionList.find(collection => collection.id === item.provider_id) + const collection = collectionList.find(collection => collection.id === item.provider_id && collection.type === item.provider_type) const icon = collection?.icon return { ...item, diff --git a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx index 9892018ab6..378054aae6 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/setting-built-in-tool.tsx @@ -8,7 +8,7 @@ import Drawer from '@/app/components/base/drawer-plus' import Form from '@/app/components/header/account-setting/model-provider-page/model-modal/Form' import { addDefaultValue, toolParametersToFormSchemas } from '@/app/components/tools/utils/to-form-schema' import type { Collection, Tool } from '@/app/components/tools/types' -import { fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' +import { fetchBuiltInToolList, fetchCustomToolList, fetchModelToolList } from '@/service/tools' import I18n from '@/context/i18n' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' @@ -19,6 +19,7 @@ import AppIcon from '@/app/components/base/app-icon' type Props = { collection: Collection isBuiltIn?: boolean + isModel?: boolean toolName: string setting?: Record readonly?: boolean @@ -29,6 +30,7 @@ type Props = { const SettingBuiltInTool: FC = ({ collection, isBuiltIn = true, + isModel = true, toolName, setting = {}, readonly, @@ -56,7 +58,11 @@ 
const SettingBuiltInTool: FC = ({ (async () => { setIsLoading(true) try { - const list = isBuiltIn ? await fetchBuiltInToolList(collection.name) : await fetchCustomToolList(collection.name) + const list = isBuiltIn + ? await fetchBuiltInToolList(collection.name) + : isModel + ? await fetchModelToolList(collection.name) + : await fetchCustomToolList(collection.name) setTools(list) const currTool = list.find(tool => tool.name === toolName) if (currTool) { diff --git a/web/app/components/app/configuration/debug/index.tsx b/web/app/components/app/configuration/debug/index.tsx index 6cf42d07b6..a7fd2d5ef7 100644 --- a/web/app/components/app/configuration/debug/index.tsx +++ b/web/app/components/app/configuration/debug/index.tsx @@ -130,7 +130,7 @@ const Debug: FC = ({ const { notify } = useContext(ToastContext) const logError = useCallback((message: string) => { - notify({ type: 'error', message }) + notify({ type: 'error', message, duration: 3000 }) }, [notify]) const [completionFiles, setCompletionFiles] = useState([]) diff --git a/web/app/components/app/overview/settings/index.tsx b/web/app/components/app/overview/settings/index.tsx index 026ba6ae10..df0f3a0062 100644 --- a/web/app/components/app/overview/settings/index.tsx +++ b/web/app/components/app/overview/settings/index.tsx @@ -12,6 +12,7 @@ import { SimpleSelect } from '@/app/components/base/select' import type { AppDetailResponse } from '@/models/app' import type { Language } from '@/types/app' import EmojiPicker from '@/app/components/base/emoji-picker' +import { useToastContext } from '@/app/components/base/toast' import { languages } from '@/i18n/language' @@ -42,6 +43,7 @@ const SettingsModal: FC = ({ onClose, onSave, }) => { + const { notify } = useToastContext() const [isShowMore, setIsShowMore] = useState(false) const { icon, icon_background } = appInfo const { title, description, copyright, privacy_policy, default_language } = appInfo.site @@ -67,6 +69,10 @@ const SettingsModal: FC = ({ } const onClickSave = async () => { + if (!inputInfo.title) { + notify({ type: 'error', message: t('app.newApp.nameNotEmpty') }) + return + } setSaveLoading(true) const params = { title: inputInfo.title, diff --git a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx index d90e4025d3..cf12b35d08 100644 --- a/web/app/components/base/chat/chat-with-history/config-panel/index.tsx +++ b/web/app/components/base/chat/chat-with-history/config-panel/index.tsx @@ -95,7 +95,10 @@ const ConfigPanel = () => { diff --git a/web/app/components/base/image-uploader/chat-image-uploader.tsx b/web/app/components/base/image-uploader/chat-image-uploader.tsx index 97f85a3137..4d34e9c15a 100644 --- a/web/app/components/base/image-uploader/chat-image-uploader.tsx +++ b/web/app/components/base/image-uploader/chat-image-uploader.tsx @@ -1,6 +1,7 @@ import type { FC } from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' +import cn from 'classnames' import Uploader from './uploader' import ImageLinkInput from './image-link-input' import { ImagePlus } from '@/app/components/base/icons/src/vender/line/images' @@ -25,16 +26,16 @@ const UploadOnlyFromLocal: FC = ({ }) => { return ( - { - hovering => ( -
( +
- -
- ) - } + `} + > + +
+ )}
) } @@ -54,13 +55,16 @@ const UploaderButton: FC = ({ const { t } = useTranslation() const [open, setOpen] = useState(false) - const hasUploadFromLocal = methods.find(method => method === TransferMethod.local_file) + const hasUploadFromLocal = methods.find( + method => method === TransferMethod.local_file, + ) const handleUpload = (imageFile: ImageFile) => { - setOpen(false) onUpload(imageFile) } + const closePopover = () => setOpen(false) + const handleToggle = () => { if (disabled) return @@ -72,43 +76,46 @@ const UploaderButton: FC = ({ -
- -
+
- -
+ +
- { - hasUploadFromLocal && ( - <> -
-
- OR -
-
- - { - hovering => ( -
- - {t('common.imageUploader.uploadFromComputer')} -
- ) - } -
- - ) - } + {hasUploadFromLocal && ( + <> +
+
+ OR +
+
+ + {hovering => ( +
+ + {t('common.imageUploader.uploadFromComputer')} +
+ )} +
+ + )}
@@ -125,7 +132,9 @@ const ChatImageUploader: FC = ({ onUpload, disabled, }) => { - const onlyUploadLocal = settings.transfer_methods.length === 1 && settings.transfer_methods[0] === TransferMethod.local_file + const onlyUploadLocal + = settings.transfer_methods.length === 1 + && settings.transfer_methods[0] === TransferMethod.local_file if (onlyUploadLocal) { return ( diff --git a/web/app/components/base/image-uploader/image-link-input.tsx b/web/app/components/base/image-uploader/image-link-input.tsx index 6c1435db30..d9ca50ac3e 100644 --- a/web/app/components/base/image-uploader/image-link-input.tsx +++ b/web/app/components/base/image-uploader/image-link-input.tsx @@ -30,6 +30,7 @@ const ImageLinkInput: FC = ({ return (
setImageLink(e.target.value)} diff --git a/web/app/components/base/image-uploader/image-list.tsx b/web/app/components/base/image-uploader/image-list.tsx index b359d2edeb..6573815950 100644 --- a/web/app/components/base/image-uploader/image-list.tsx +++ b/web/app/components/base/image-uploader/image-list.tsx @@ -1,7 +1,11 @@ import type { FC } from 'react' import { useState } from 'react' import { useTranslation } from 'react-i18next' -import { Loading02, XClose } from '@/app/components/base/icons/src/vender/line/general' +import cn from 'classnames' +import { + Loading02, + XClose, +} from '@/app/components/base/icons/src/vender/line/general' import { RefreshCcw01 } from '@/app/components/base/icons/src/vender/line/arrows' import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' import TooltipPlus from '@/app/components/base/tooltip-plus' @@ -30,7 +34,11 @@ const ImageList: FC = ({ const [imagePreviewUrl, setImagePreviewUrl] = useState('') const handleImageLinkLoadSuccess = (item: ImageFile) => { - if (item.type === TransferMethod.remote_url && onImageLinkLoadSuccess && item.progress !== -1) + if ( + item.type === TransferMethod.remote_url + && onImageLinkLoadSuccess + && item.progress !== -1 + ) onImageLinkLoadSuccess(item._id) } const handleImageLinkLoadError = (item: ImageFile) => { @@ -39,89 +47,95 @@ const ImageList: FC = ({ } return ( -
- { - list.map(item => ( -
- { - item.type === TransferMethod.local_file && item.progress !== 100 && ( - <> -
-1 ? `${item.progress}%` : 0 }} - > - { - item.progress === -1 && ( - onReUpload && onReUpload(item._id)} /> - ) - } -
- { - item.progress > -1 && ( - {item.progress}% - ) - } - - ) - } - { - item.type === TransferMethod.remote_url && item.progress !== 100 && ( -
+ {list.map(item => ( +
+ {item.type === TransferMethod.local_file && item.progress !== 100 && ( + <> +
-1 ? `${item.progress}%` : 0 }} + > + {item.progress === -1 && ( + onReUpload && onReUpload(item._id)} + /> + )} +
+ {item.progress > -1 && ( + + {item.progress}% + + )} + + )} + {item.type === TransferMethod.remote_url && item.progress !== 100 && ( +
- { - item.progress > -1 && ( - - ) - } - { - item.progress === -1 && ( - - - - ) - } -
- ) + ${ + item.progress === -1 + ? 'bg-[#FEF0C7] border-[#DC6803]' + : 'bg-black/[0.16] border-transparent' } - handleImageLinkLoadSuccess(item)} - onError={() => handleImageLinkLoadError(item)} - src={item.type === TransferMethod.remote_url ? item.url : item.base64Url} - onClick={() => item.progress === 100 && setImagePreviewUrl((item.type === TransferMethod.remote_url ? item.url : item.base64Url) as string)} - /> - { - !readonly && ( -
onRemove && onRemove(item._id)} + `} + > + {item.progress > -1 && ( + + )} + {item.progress === -1 && ( + - -
+ + + )} +
+ )} + {item.file?.name} handleImageLinkLoadSuccess(item)} + onError={() => handleImageLinkLoadError(item)} + src={ + item.type === TransferMethod.remote_url + ? item.url + : item.base64Url + } + onClick={() => + item.progress === 100 + && setImagePreviewUrl( + (item.type === TransferMethod.remote_url + ? item.url + : item.base64Url) as string, ) } -
- )) - } - { - imagePreviewUrl && ( - setImagePreviewUrl('')} /> - ) - } + {!readonly && ( + + )} +
+ ))} + {imagePreviewUrl && ( + setImagePreviewUrl('')} + /> + )}
) } diff --git a/web/app/components/base/image-uploader/uploader.tsx b/web/app/components/base/image-uploader/uploader.tsx index f43c24c3f6..c6f5e707eb 100644 --- a/web/app/components/base/image-uploader/uploader.tsx +++ b/web/app/components/base/image-uploader/uploader.tsx @@ -7,6 +7,7 @@ import { ALLOW_FILE_EXTENSIONS } from '@/types/app' type UploaderProps = { children: (hovering: boolean) => JSX.Element onUpload: (imageFile: ImageFile) => void + closePopover?: () => void limit?: number disabled?: boolean } @@ -14,11 +15,16 @@ type UploaderProps = { const Uploader: FC = ({ children, onUpload, + closePopover, limit, disabled, }) => { const [hovering, setHovering] = useState(false) - const { handleLocalFileUpload } = useLocalFileUploader({ limit, onUpload, disabled }) + const { handleLocalFileUpload } = useLocalFileUploader({ + limit, + onUpload, + disabled, + }) const handleChange = (e: ChangeEvent) => { const file = e.target.files?.[0] @@ -27,6 +33,7 @@ const Uploader: FC = ({ return handleLocalFileUpload(file) + closePopover?.() } return ( @@ -37,11 +44,8 @@ const Uploader: FC = ({ > {children(hovering)} (e.target as HTMLInputElement).value = ''} + className='absolute block inset-0 opacity-0 text-[0] w-full disabled:cursor-not-allowed cursor-pointer' + onClick={e => ((e.target as HTMLInputElement).value = '')} type='file' accept={ALLOW_FILE_EXTENSIONS.map(ext => `.${ext}`).join(',')} onChange={handleChange} diff --git a/web/app/components/base/prompt-editor/index.tsx b/web/app/components/base/prompt-editor/index.tsx index b35b4c71f9..ea946eec46 100644 --- a/web/app/components/base/prompt-editor/index.tsx +++ b/web/app/components/base/prompt-editor/index.tsx @@ -32,6 +32,7 @@ import VariableValueBlock from './plugins/variable-value-block' import { VariableValueBlockNode } from './plugins/variable-value-block/node' import { CustomTextNode } from './plugins/custom-text/node' import OnBlurBlock from './plugins/on-blur-or-focus-block' +import UpdateBlock from './plugins/update-block' import { textToEditorState } from './utils' import type { Dataset } from './plugins/context-block' import type { RoleName } from './plugins/history-block' @@ -226,6 +227,7 @@ const PromptEditor: FC = ({ + {/* */}
diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx new file mode 100644 index 0000000000..df84aa02e6 --- /dev/null +++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx @@ -0,0 +1,21 @@ +import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' +import { textToEditorState } from '../utils' +import { useEventEmitterContextContext } from '@/context/event-emitter' + +export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER' + +const UpdateBlock = () => { + const { eventEmitter } = useEventEmitterContextContext() + const [editor] = useLexicalComposerContext() + + eventEmitter?.useSubscription((v: any) => { + if (v.type === PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER) { + const editorState = editor.parseEditorState(textToEditorState(v.payload)) + editor.setEditorState(editorState) + } + }) + + return null +} + +export default UpdateBlock diff --git a/web/app/components/datasets/create/step-two/index.tsx b/web/app/components/datasets/create/step-two/index.tsx index 4688d7afee..55ea939292 100644 --- a/web/app/components/datasets/create/step-two/index.tsx +++ b/web/app/components/datasets/create/step-two/index.tsx @@ -838,7 +838,7 @@ const StepTwo = ({ {!isSetting ? (
- +
diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index ca400d1438..d4d1e8e4b6 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -427,6 +427,53 @@ Chat applications support session persistence, allowing previous chat history to --- + + + + Get next questions suggestions for the current message + + ### Path Params + + + + Message ID + + + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \ + --header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \ + --header 'Content-Type: application/json' \ + ``` + + + + + ```json {{ title: 'Response' }} + { + "result": "success", + "data": [ + "a", + "b", + "c" + ] + } + ``` + + + + +--- + + + + 获取下一轮建议问题列表。 + + ### Path Params + + + + Message ID + + + + + + + + ```bash {{ title: 'cURL' }} + curl --location --request GET '${props.appDetail.api_base_url}/messages/{message_id}/suggested' \ + --header 'Authorization: Bearer ENTER-YOUR-SECRET-KEY' \ + --header 'Content-Type: application/json' \ + ``` + + + + + ```json {{ title: 'Response' }} + { + "result": "success", + "data": [ + "a", + "b", + "c" + ] + } + ``` + + + + +--- + +--- + = ({ const showCollectionList = (() => { let typeFilteredList: Collection[] = [] if (collectionType === CollectionType.all) - typeFilteredList = collectionList - else - typeFilteredList = collectionList.filter(item => item.type === collectionType) + typeFilteredList = collectionList.filter(item => item.type !== CollectionType.model) + else if (collectionType === CollectionType.builtIn) + typeFilteredList = collectionList.filter(item => item.type === CollectionType.builtIn) + else if (collectionType === CollectionType.custom) + typeFilteredList = collectionList.filter(item => item.type === CollectionType.custom) if (query) return typeFilteredList.filter(item => item.name.includes(query)) @@ -122,6 +124,10 @@ const Tools: FC = ({ const list = await fetchBuiltInToolList(currCollection.name) setCurrentTools(list) } + else if (currCollection.type === CollectionType.model) { + const list = await fetchModelToolList(currCollection.name) + setCurrentTools(list) + } else { const list = await fetchCustomToolList(currCollection.name) setCurrentTools(list) @@ -130,7 +136,7 @@ const Tools: FC = ({ catch (e) { } setIsDetailLoading(false) })() - }, [currCollection?.name]) + }, [currCollection?.name, currCollection?.type]) const [isShowEditCollectionToolModal, setIsShowEditCollectionToolModal] = useState(false) const handleCreateToolCollection = () => { @@ -197,7 +203,7 @@ const Tools: FC = ({ (showCollectionList.length > 0 || !query) ? diff --git a/web/app/components/tools/tool-list/header.tsx b/web/app/components/tools/tool-list/header.tsx index 5a243a0a2b..bf564f320f 100644 --- a/web/app/components/tools/tool-list/header.tsx +++ b/web/app/components/tools/tool-list/header.tsx @@ -29,9 +29,8 @@ const Header: FC = ({ const { t } = useTranslation() const isInToolsPage = loc === LOC.tools const isInDebugPage = !isInToolsPage - const needAuth = collection?.allow_delete - // const isBuiltIn = collection.type === CollectionType.builtIn + const needAuth = collection?.allow_delete || collection?.type === CollectionType.model const isAuthed = collection.is_team_authorization return (
@@ -50,10 +49,13 @@ const Header: FC = ({ )}
- {collection.type === CollectionType.builtIn && needAuth && ( + {(collection.type === CollectionType.builtIn || collection.type === CollectionType.model) && needAuth && (
onShowAuth()} + onClick={() => { + if (collection.type === CollectionType.builtIn || collection.type === CollectionType.model) + onShowAuth() + }} >
{t(`tools.auth.${isAuthed ? 'authorized' : 'unauthorized'}`)}
diff --git a/web/app/components/tools/tool-list/index.tsx b/web/app/components/tools/tool-list/index.tsx index 3bee3292e6..9228a028a5 100644 --- a/web/app/components/tools/tool-list/index.tsx +++ b/web/app/components/tools/tool-list/index.tsx @@ -8,6 +8,7 @@ import type { Collection, CustomCollectionBackend, Tool } from '../types' import Loading from '../../base/loading' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' import Toast from '../../base/toast' +import { ConfigurateMethodEnum } from '../../header/account-setting/model-provider-page/declarations' import Header from './header' import Item from './item' import AppIcon from '@/app/components/base/app-icon' @@ -16,6 +17,8 @@ import { fetchCustomCollection, removeBuiltInToolCredential, removeCustomCollect import EditCustomToolModal from '@/app/components/tools/edit-custom-collection-modal' import type { AgentTool } from '@/types/app' import { MAX_TOOLS_NUM } from '@/config' +import { useModalContext } from '@/context/modal-context' +import { useProviderContext } from '@/context/provider-context' type Props = { collection: Collection | null @@ -42,9 +45,32 @@ const ToolList: FC = ({ const { t } = useTranslation() const isInToolsPage = loc === LOC.tools const isBuiltIn = collection?.type === CollectionType.builtIn + const isModel = collection?.type === CollectionType.model const needAuth = collection?.allow_delete + const { setShowModelModal } = useModalContext() const [showSettingAuth, setShowSettingAuth] = useState(false) + const { modelProviders: providers } = useProviderContext() + const showSettingAuthModal = () => { + if (isModel) { + const provider = providers.find(item => item.provider === collection?.id) + if (provider) { + setShowModelModal({ + payload: { + currentProvider: provider, + currentConfigurateMethod: ConfigurateMethodEnum.predefinedModel, + currentCustomConfigrationModelFixedFields: undefined, + }, + onSaveCallback: () => { + onRefreshData() + }, + }) + } + } + else { + setShowSettingAuth(true) + } + } const [customCollection, setCustomCollection] = useState(null) useEffect(() => { @@ -116,7 +142,7 @@ const ToolList: FC = ({ icon={icon} collection={collection} loc={loc} - onShowAuth={() => setShowSettingAuth(true)} + onShowAuth={() => showSettingAuthModal()} onShowEditCustomCollection={() => setIsShowEditCustomCollectionModal(true)} />
@@ -124,12 +150,12 @@ const ToolList: FC = ({
{t('tools.includeToolNum', { num: list.length, })}
- {needAuth && isBuiltIn && !collection.is_team_authorization && ( + {needAuth && (isBuiltIn || isModel) && !collection.is_team_authorization && ( <>
·
setShowSettingAuth(true)} + onClick={() => showSettingAuthModal()} >
{t('tools.auth.setup')}
@@ -149,7 +175,7 @@ const ToolList: FC = ({ collection={collection} isInToolsPage={isInToolsPage} isToolNumMax={(addedTools?.length || 0) >= MAX_TOOLS_NUM} - added={!!addedTools?.find(v => v.provider_id === collection.id && v.tool_name === item.name)} + added={!!addedTools?.find(v => v.provider_id === collection.id && v.provider_type === collection.type && v.tool_name === item.name)} onAdd={!isInToolsPage ? tool => onAddTool?.(collection as Collection, tool) : undefined} /> ))} diff --git a/web/app/components/tools/tool-list/item.tsx b/web/app/components/tools/tool-list/item.tsx index c53aba61d3..e6e07cd5c7 100644 --- a/web/app/components/tools/tool-list/item.tsx +++ b/web/app/components/tools/tool-list/item.tsx @@ -35,6 +35,7 @@ const Item: FC = ({ const language = getLanguage(locale) const isBuiltIn = collection.type === CollectionType.builtIn + const isModel = collection.type === CollectionType.model const canShowDetail = isInToolsPage const [showDetail, setShowDetail] = useState(false) const addBtn = @@ -73,6 +74,7 @@ const Item: FC = ({ setShowDetail(false) }} isBuiltIn={isBuiltIn} + isModel={isModel} /> )} diff --git a/web/app/components/tools/tool-nav-list/index.tsx b/web/app/components/tools/tool-nav-list/index.tsx index 1fab9de7a3..3a8fd4088b 100644 --- a/web/app/components/tools/tool-nav-list/index.tsx +++ b/web/app/components/tools/tool-nav-list/index.tsx @@ -6,21 +6,21 @@ import Item from './item' import type { Collection } from '@/app/components/tools/types' type Props = { className?: string - currentName: string + currentIndex: number list: Collection[] onChosen: (index: number) => void } const ToolNavList: FC = ({ className, - currentName, + currentIndex, list, onChosen, }) => { return (
{list.map((item, index) => ( - onChosen(index)}> + onChosen(index)}> ))}
) diff --git a/web/app/components/tools/types.ts b/web/app/components/tools/types.ts index 389276e81c..6de8d8aa76 100644 --- a/web/app/components/tools/types.ts +++ b/web/app/components/tools/types.ts @@ -26,6 +26,7 @@ export enum CollectionType { all = 'all', builtIn = 'builtin', custom = 'api', + model = 'model', } export type Emoji = { diff --git a/web/i18n/en-US/dataset-creation.ts b/web/i18n/en-US/dataset-creation.ts index 61f32436fd..8923170f7f 100644 --- a/web/i18n/en-US/dataset-creation.ts +++ b/web/i18n/en-US/dataset-creation.ts @@ -89,7 +89,7 @@ const translation = { other: 'and other ', fileUnit: ' files', notionUnit: ' pages', - lastStep: 'Last step', + previousStep: 'Previous step', nextStep: 'Save & Process', save: 'Save & Process', cancel: 'Cancel', diff --git a/web/i18n/ja-JP/dataset-creation.ts b/web/i18n/ja-JP/dataset-creation.ts index b08d0224f9..384c655214 100644 --- a/web/i18n/ja-JP/dataset-creation.ts +++ b/web/i18n/ja-JP/dataset-creation.ts @@ -89,7 +89,7 @@ const translation = { other: 'その他', fileUnit: 'ファイル', notionUnit: 'ページ', - lastStep: '最後のステップ', + previousStep: '前のステップ', nextStep: '保存して処理', save: '保存して処理', cancel: 'キャンセル', diff --git a/web/i18n/pt-BR/dataset-creation.ts b/web/i18n/pt-BR/dataset-creation.ts index 08018eae61..b721f2177b 100644 --- a/web/i18n/pt-BR/dataset-creation.ts +++ b/web/i18n/pt-BR/dataset-creation.ts @@ -89,7 +89,7 @@ const translation = { other: 'e outros ', fileUnit: ' arquivos', notionUnit: ' páginas', - lastStep: 'Última etapa', + previousStep: 'Passo anterior', nextStep: 'Salvar e Processar', save: 'Salvar e Processar', cancel: 'Cancelar', diff --git a/web/i18n/uk-UA/dataset-creation.ts b/web/i18n/uk-UA/dataset-creation.ts index 7ba648c38f..6c0099a771 100644 --- a/web/i18n/uk-UA/dataset-creation.ts +++ b/web/i18n/uk-UA/dataset-creation.ts @@ -89,7 +89,7 @@ const translation = { other: ' та інші ', fileUnit: ' файли', notionUnit: ' сторінки', - lastStep: 'Попередній крок', + previousStep: 'Попередній крок', nextStep: 'Зберегти та обробити', save: 'Зберегти та обробити', cancel: 'Скасувати', diff --git a/web/i18n/zh-Hans/dataset-creation.ts b/web/i18n/zh-Hans/dataset-creation.ts index dc401e1ce2..d36850dc3d 100644 --- a/web/i18n/zh-Hans/dataset-creation.ts +++ b/web/i18n/zh-Hans/dataset-creation.ts @@ -89,7 +89,7 @@ const translation = { other: '和其他 ', fileUnit: ' 个文件', notionUnit: ' 个页面', - lastStep: '上一步', + previousStep: '上一步', nextStep: '保存并处理', save: '保存并处理', cancel: '取消', diff --git a/web/service/tools.ts b/web/service/tools.ts index 008de4a557..ac59e2e508 100644 --- a/web/service/tools.ts +++ b/web/service/tools.ts @@ -12,6 +12,11 @@ export const fetchBuiltInToolList = (collectionName: string) => { export const fetchCustomToolList = (collectionName: string) => { return get(`/workspaces/current/tool-provider/api/tools?provider=${collectionName}`) } + +export const fetchModelToolList = (collectionName: string) => { + return get(`/workspaces/current/tool-provider/model/tools?provider=${collectionName}`) +} + export const fetchBuiltInToolCredentialSchema = (collectionName: string) => { return get(`/workspaces/current/tool-provider/builtin/${collectionName}/credentials_schema`) }