diff --git a/README.md b/README.md index 7e2740b10e..14111ee060 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,17 @@
+ + Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈] + +
Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs diff --git a/api/celerybeat-schedule.db b/api/celerybeat-schedule.db deleted file mode 100644 index b8c01de27b..0000000000 Binary files a/api/celerybeat-schedule.db and /dev/null differ diff --git a/api/commands.py b/api/commands.py index 91b50445e6..9923ccb8b8 100644 --- a/api/commands.py +++ b/api/commands.py @@ -6,15 +6,15 @@ import click from flask import current_app from werkzeug.exceptions import NotFound -from core.embedding.cached_embedding import CacheEmbedding -from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.models.document import Document from extensions.ext_database import db from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair from models.account import Tenant -from models.dataset import Dataset +from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment +from models.dataset import Document as DatasetDocument from models.model import Account from models.provider import Provider, ProviderModel @@ -124,14 +124,15 @@ def reset_encrypt_key_pair(): 'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green')) -@click.command('create-qdrant-indexes', help='Create qdrant indexes.') -def create_qdrant_indexes(): +@click.command('vdb-migrate', help='Migrate vector db.') +def vdb_migrate(): """ - Migrate other vector database datas to Qdrant. + Migrate vector database data to the target vector database. """ - click.echo(click.style('Start create qdrant indexes.', fg='green')) + click.echo(click.style('Start migrating vector db.', fg='green')) create_count = 0 - + config = current_app.config + vector_type = config.get('VECTOR_STORE') page = 1 while True: try: @@ -140,54 +141,101 @@ except NotFound: break - model_manager = ModelManager() - page += 1 for dataset in datasets: - if dataset.index_struct_dict: - if dataset.index_struct_dict['type'] != 'qdrant': - try: - click.echo('Create dataset qdrant index: {}'.format(dataset.id)) - try: - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model - - ) - except Exception: - continue - embeddings = CacheEmbedding(embedding_model) - - from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex - - index = QdrantVectorIndex( - dataset=dataset, - config=QdrantConfig( - endpoint=current_app.config.get('QDRANT_URL'), - api_key=current_app.config.get('QDRANT_API_KEY'), - root_path=current_app.root_path - ), - embeddings=embeddings - ) - if index: - index.create_qdrant_dataset(dataset) - index_struct = { - "type": 'qdrant', - "vector_store": { - "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']} - } - dataset.index_struct = json.dumps(index_struct) - db.session.commit() - create_count += 1 - else: - click.echo('passed.') - except Exception as e: - click.echo( - click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), - fg='red')) + try: + click.echo('Create dataset vdb index: {}'.format(dataset.id)) + if dataset.index_struct_dict: + if dataset.index_struct_dict['type'] == vector_type: continue + if vector_type == "weaviate": + dataset_id = 
dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + index_struct_dict = { + "type": 'weaviate', + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) + elif vector_type == "qdrant": + if dataset.collection_binding_id: + dataset_collection_binding = db.session.query(DatasetCollectionBinding). \ + filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \ + one_or_none() + if dataset_collection_binding: + collection_name = dataset_collection_binding.collection_name + else: + raise ValueError('Dataset Collection Binding does not exist!') + else: + dataset_id = dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + index_struct_dict = { + "type": 'qdrant', + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) + + elif vector_type == "milvus": + dataset_id = dataset.id + collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + index_struct_dict = { + "type": 'milvus', + "vector_store": {"class_prefix": collection_name} + } + dataset.index_struct = json.dumps(index_struct_dict) + else: + raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + + vector = Vector(dataset) + click.echo(f"vdb_migrate {dataset.id}") + + try: + vector.delete() + except Exception as e: + raise e + + dataset_documents = db.session.query(DatasetDocument).filter( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == 'completed', + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ).all() + + documents = [] + for dataset_document in dataset_documents: + segments = db.session.query(DocumentSegment).filter( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == 'completed', + DocumentSegment.enabled == True + ).all() + + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + } + ) + + documents.append(document) + + if documents: + try: + vector.create(documents) + except Exception as e: + raise e + click.echo(f"Dataset {dataset.id} created successfully.") + db.session.add(dataset) + db.session.commit() + create_count += 1 + except Exception as e: + db.session.rollback() + click.echo( + click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)), + fg='red')) + continue click.echo(click.style('Congratulations! 
Create {} dataset indexes.'.format(create_count), fg='green')) @@ -196,4 +244,4 @@ def register_commands(app): app.cli.add_command(reset_password) app.cli.add_command(reset_email) app.cli.add_command(reset_encrypt_key_pair) - app.cli.add_command(create_qdrant_indexes) + app.cli.add_command(vdb_migrate) diff --git a/api/config.py b/api/config.py index 83336e6c45..3f6980bdea 100644 --- a/api/config.py +++ b/api/config.py @@ -38,7 +38,9 @@ DEFAULTS = { 'LOG_LEVEL': 'INFO', 'HOSTED_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_OPENAI_TRIAL_ENABLED': 'False', + 'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003', 'HOSTED_OPENAI_PAID_ENABLED': 'False', + 'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003', 'HOSTED_AZURE_OPENAI_ENABLED': 'False', 'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200, 'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000, @@ -88,7 +90,7 @@ class Config: # ------------------------ # General Configurations. # ------------------------ - self.CURRENT_VERSION = "0.5.6" + self.CURRENT_VERSION = "0.5.7" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = "SELF_HOSTED" self.DEPLOY_ENV = get_env('DEPLOY_ENV') @@ -261,8 +263,10 @@ class Config: self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE') self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED') + self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS') self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT')) self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') + self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS') self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 87cad07462..59a7535144 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -124,19 +124,13 @@ class AppListApi(Resource): available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models] provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}" if provider_model not in available_models_names: - model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.LLM - ) - - if not model_instance: + if not default_model_entity: raise ProviderNotInitializeError( "No Default System Reasoning Model available. 
Please configure " "in the Settings -> Model Provider.") else: - model_config_dict["model"]["provider"] = model_instance.provider - model_config_dict["model"]["name"] = model_instance.model + model_config_dict["model"]["provider"] = default_model_entity.provider + model_config_dict["model"]["name"] = default_model_entity.model model_configuration = AppModelConfigService.validate_configuration( tenant_id=current_user.current_tenant_id, diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index c0c345baea..f3e639c6ac 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -178,7 +178,8 @@ class DataSourceNotionApi(Resource): notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, - notion_access_token=data_source_binding.access_token + notion_access_token=data_source_binding.access_token, + tenant_id=current_user.current_tenant_id ) text_docs = extractor.extract() @@ -208,7 +209,8 @@ class DataSourceNotionApi(Resource): notion_info={ "notion_workspace_id": workspace_id, "notion_obj_id": page['page_id'], - "notion_page_type": page['type'] + "notion_page_type": page['type'], + "tenant_id": current_user.current_tenant_id }, document_model=args['doc_form'] ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index f80b4de48d..e633631c42 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -298,7 +298,8 @@ class DatasetIndexingEstimateApi(Resource): notion_info={ "notion_workspace_id": workspace_id, "notion_obj_id": page['page_id'], - "notion_page_type": page['type'] + "notion_page_type": page['type'], + "tenant_id": current_user.current_tenant_id }, document_model=args['doc_form'] ) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index a990ef96ee..c383cdc762 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -455,7 +455,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): notion_info={ "notion_workspace_id": data_source_info['notion_workspace_id'], "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'] + "notion_page_type": data_source_info['type'], + "tenant_id": current_user.current_tenant_id }, document_model=document.doc_form ) diff --git a/api/controllers/service_api/app/__init__.py b/api/controllers/service_api/app/__init__.py index d8018ee385..e69de29bb2 100644 --- a/api/controllers/service_api/app/__init__.py +++ b/api/controllers/service_api/app/__init__.py @@ -1,27 +0,0 @@ -from extensions.ext_database import db -from models.model import EndUser - - -def create_or_update_end_user_for_user_id(app_model, user_id): - """ - Create or update session terminal based on user ID. 
- """ - end_user = db.session.query(EndUser) \ - .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() - - if end_user is None: - end_user = EndUser( - tenant_id=app_model.tenant_id, - app_id=app_model.id, - type='service_api', - is_anonymous=True, - session_id=user_id - ) - db.session.add(end_user) - db.session.commit() - - return end_user diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 9cd9770c09..a3151fc4a2 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,16 +1,16 @@ import json from flask import current_app -from flask_restful import fields, marshal_with +from flask_restful import fields, marshal_with, Resource from controllers.service_api import api -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import validate_app_token from extensions.ext_database import db from models.model import App, AppModelConfig from models.tools import ApiToolProvider -class AppParameterApi(AppApiResource): +class AppParameterApi(Resource): """Resource for app variables.""" variable_fields = { @@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource): 'system_parameters': fields.Nested(system_parameters_fields) } + @validate_app_token @marshal_with(parameters_fields) - def get(self, app_model: App, end_user): + def get(self, app_model: App): """Retrieve app parameters.""" app_model_config = app_model.app_model_config @@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource): } } -class AppMetaApi(AppApiResource): - def get(self, app_model: App, end_user): +class AppMetaApi(Resource): + @validate_app_token + def get(self, app_model: App): """Get app meta""" app_model_config: AppModelConfig = app_model.app_model_config diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index d2906b1d6e..58ab56a292 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,7 +1,7 @@ import logging from flask import request -from flask_restful import reqparse +from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError import services @@ -17,10 +17,10 @@ from controllers.service_api.app.error import ( ProviderQuotaExceededError, UnsupportedAudioTypeError, ) -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from models.model import App, AppModelConfig +from models.model import App, AppModelConfig, EndUser from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, @@ -30,8 +30,9 @@ from services.errors.audio import ( ) -class AudioApi(AppApiResource): - def post(self, app_model: App, end_user): +class AudioApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) + def post(self, app_model: App, end_user: EndUser): app_model_config: AppModelConfig = app_model.app_model_config if not app_model_config.speech_to_text_dict['enabled']: @@ -73,11 +74,11 @@ class AudioApi(AppApiResource): raise InternalServerError() -class TextApi(AppApiResource): - def post(self, app_model: App, end_user): +class TextApi(Resource): + 
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser): parser = reqparse.RequestParser() parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('user', type=str, required=True, nullable=False, location='json') parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json') args = parser.parse_args() @@ -85,7 +86,7 @@ class TextApi(AppApiResource): response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, text=args['text'], - end_user=args['user'], + end_user=end_user, voice=app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=args['streaming'] ) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 5331f796e7..c6cfb24378 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,12 +4,11 @@ from collections.abc import Generator from typing import Union from flask import Response, stream_with_context -from flask_restful import reqparse +from flask_restful import Resource, reqparse from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -19,17 +18,19 @@ from controllers.service_api.app.error import ( ProviderNotInitializeError, ProviderQuotaExceededError, ) -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value +from models.model import App, EndUser from services.completion_service import CompletionService -class CompletionApi(AppApiResource): - def post(self, app_model, end_user): +class CompletionApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser): if app_model.mode != 'completion': raise AppUnavailableError() @@ -38,16 +39,12 @@ class CompletionApi(AppApiResource): parser.add_argument('query', type=str, location='json', default='') parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('user', required=True, nullable=False, type=str, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') args = parser.parse_args() streaming = args['response_mode'] == 'streaming' - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - args['auto_generate_name'] = False try: @@ -82,29 +79,20 @@ class CompletionApi(AppApiResource): raise InternalServerError() -class CompletionStopApi(AppApiResource): - def post(self, app_model, end_user, task_id): +class CompletionStopApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, 
required=True)) + def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'completion': raise AppUnavailableError() - if end_user is None: - parser = reqparse.RequestParser() - parser.add_argument('user', required=True, nullable=False, type=str, location='json') - args = parser.parse_args() - - user = args.get('user') - if user is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user) - else: - raise ValueError("arg user muse be input.") - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 -class ChatApi(AppApiResource): - def post(self, app_model, end_user): +class ChatApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser): if app_model.mode != 'chat': raise NotChatAppError() @@ -114,7 +102,6 @@ class ChatApi(AppApiResource): parser.add_argument('files', type=list, required=False, location='json') parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('user', type=str, required=True, nullable=False, location='json') parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') @@ -122,9 +109,6 @@ class ChatApi(AppApiResource): streaming = args['response_mode'] == 'streaming' - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: response = CompletionService.completion( app_model=app_model, @@ -157,22 +141,12 @@ class ChatApi(AppApiResource): raise InternalServerError() -class ChatStopApi(AppApiResource): - def post(self, app_model, end_user, task_id): +class ChatStopApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) + def post(self, app_model: App, end_user: EndUser, task_id): if app_model.mode != 'chat': raise NotChatAppError() - if end_user is None: - parser = reqparse.RequestParser() - parser.add_argument('user', required=True, nullable=False, type=str, location='json') - args = parser.parse_args() - - user = args.get('user') - if user is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user) - else: - raise ValueError("arg user muse be input.") - ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) return {'result': 'success'}, 200 diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 3c157bed99..4a5fe2f19f 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,52 +1,44 @@ -from flask import request -from flask_restful import marshal_with, reqparse +from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import NotFound import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError -from controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.conversation_fields import 
conversation_infinite_scroll_pagination_fields, simple_conversation_fields from libs.helper import uuid_value +from models.model import App, EndUser from services.conversation_service import ConversationService -class ConversationApi(AppApiResource): +class ConversationApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): if app_model.mode != 'chat': raise NotChatAppError() parser = reqparse.RequestParser() parser.add_argument('last_id', type=uuid_value, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('user', type=str, location='args') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") -class ConversationDetailApi(AppApiResource): +class ConversationDetailApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) - def delete(self, app_model, end_user, c_id): + def delete(self, app_model: App, end_user: EndUser, c_id): if app_model.mode != 'chat': raise NotChatAppError() conversation_id = str(c_id) - user = request.get_json().get('user') - - if end_user is None and user is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user) - try: ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: @@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource): return {"result": "success"}, 204 -class ConversationRenameApi(AppApiResource): +class ConversationRenameApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) - def post(self, app_model, end_user, c_id): + def post(self, app_model: App, end_user: EndUser, c_id): if app_model.mode != 'chat': raise NotChatAppError() @@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource): parser = reqparse.RequestParser() parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('user', type=str, location='json') parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: return ConversationService.rename( app_model, diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index a901375ec0..5dbc1b1d1b 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,30 +1,27 @@ from flask import request -from flask_restful import marshal_with +from flask_restful import Resource, marshal_with import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import ( FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError, ) -from 
controllers.service_api.wraps import AppApiResource +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.file_fields import file_fields +from models.model import App, EndUser from services.file_service import FileService -class FileApi(AppApiResource): +class FileApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) - def post(self, app_model, end_user): + def post(self, app_model: App, end_user: EndUser): file = request.files['file'] - user_args = request.form.get('user') - - if end_user is None and user_args is not None: - end_user = create_or_update_end_user_for_user_id(app_model, user_args) # check file if 'file' not in request.files: diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index d90f536a42..0050ab1aee 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,20 +1,18 @@ -from flask_restful import fields, marshal_with, reqparse +from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import NotFound import services from controllers.service_api import api -from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError -from controllers.service_api.wraps import AppApiResource -from extensions.ext_database import db +from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from fields.conversation_fields import message_file_fields from libs.helper import TimestampField, uuid_value -from models.model import EndUser, Message +from models.model import App, EndUser from services.message_service import MessageService -class MessageListApi(AppApiResource): +class MessageListApi(Resource): feedback_fields = { 'rating': fields.String } @@ -70,8 +68,9 @@ class MessageListApi(AppApiResource): 'data': fields.List(fields.Nested(message_fields)) } + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(message_infinite_scroll_pagination_fields) - def get(self, app_model, end_user): + def get(self, app_model: App, end_user: EndUser): if app_model.mode != 'chat': raise NotChatAppError() @@ -79,12 +78,8 @@ class MessageListApi(AppApiResource): parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') parser.add_argument('first_id', type=uuid_value, location='args') parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('user', type=str, location='args') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: return MessageService.pagination_by_first_id(app_model, end_user, args['conversation_id'], args['first_id'], args['limit']) @@ -94,18 +89,15 @@ class MessageListApi(AppApiResource): raise NotFound("First Message Not Exists.") -class MessageFeedbackApi(AppApiResource): - def post(self, app_model, end_user, message_id): +class MessageFeedbackApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + def post(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) parser = reqparse.RequestParser() parser.add_argument('rating', type=str, choices=['like', 'dislike', None], 
location='json') - parser.add_argument('user', type=str, location='json') args = parser.parse_args() - if end_user is None and args['user'] is not None: - end_user = create_or_update_end_user_for_user_id(app_model, args['user']) - try: MessageService.create_feedback(app_model, message_id, end_user, args['rating']) except services.errors.message.MessageNotExistsError: @@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource): return {'result': 'success'} -class MessageSuggestedApi(AppApiResource): - def get(self, app_model, end_user, message_id): +class MessageSuggestedApi(Resource): + @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) + def get(self, app_model: App, end_user: EndUser, message_id): message_id = str(message_id) if app_model.mode != 'chat': raise NotChatAppError() - try: - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id, - ).first() - if end_user is None and message.from_end_user_id is not None: - user = db.session.query(EndUser) \ - .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.id == message.from_end_user_id, - EndUser.type == 'service_api' - ).first() - else: - user = end_user + try: questions = MessageService.get_suggested_questions_after_answer( app_model=app_model, - user=user, + user=end_user, message_id=message_id, check_enabled=False ) diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index a0d89fe62f..9819c73d37 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,22 +1,40 @@ +from collections.abc import Callable from datetime import datetime +from enum import Enum from functools import wraps +from typing import Optional from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource +from pydantic import BaseModel from werkzeug.exceptions import NotFound, Unauthorized from extensions.ext_database import db from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin -from models.model import ApiToken, App +from models.model import ApiToken, App, EndUser from services.feature_service import FeatureService -def validate_app_token(view=None): - def decorator(view): - @wraps(view) - def decorated(*args, **kwargs): +class WhereisUserArg(Enum): + """ + Enum for whereis_user_arg. 
+ """ + QUERY = 'query' + JSON = 'json' + FORM = 'form' + + +class FetchUserArg(BaseModel): + fetch_from: WhereisUserArg + required: bool = False + + +def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None): + def decorator(view_func): + @wraps(view_func) + def decorated_view(*args, **kwargs): api_token = validate_and_get_api_token('app') app_model = db.session.query(App).filter(App.id == api_token.app_id).first() @@ -29,16 +47,35 @@ if not app_model.enable_api: raise NotFound() - return view(app_model, None, *args, **kwargs) - return decorated + kwargs['app_model'] = app_model - if view: + if fetch_user_arg: + if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: + user_id = request.args.get('user') + elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: + user_id = request.get_json().get('user') + elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: + user_id = request.form.get('user') + else: + # fall back to the default user + user_id = None + + if not user_id and fetch_user_arg.required: + raise ValueError("Arg user must be provided.") + + if user_id: + user_id = str(user_id) + + kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + + return view_func(*args, **kwargs) + return decorated_view + + if view is None: + return decorator + else: return decorator(view) - # if view is None, it means that the decorator is used without parentheses - use the decorator as a function for method_decorators - return decorator - def cloud_edition_billing_resource_check(resource: str, api_token_type: str, @@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None): return api_token -class AppApiResource(Resource): - method_decorators = [validate_app_token] +def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser: + """ + Create or update an end user session based on the user ID. 
+ """ + if not user_id: + user_id = 'DEFAULT-USER' + + end_user = db.session.query(EndUser) \ + .filter( + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == 'service_api' + ).first() + + if end_user is None: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type='service_api', + is_anonymous=True if user_id == 'DEFAULT-USER' else False, + session_id=user_id + ) + db.session.add(end_user) + db.session.commit() + + return end_user class DatasetApiResource(Resource): diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py deleted file mode 100644 index 9c0f9c5b36..0000000000 --- a/api/core/agent/agent/calc_token_mixin.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import cast - -from core.entities.application_entities import ModelConfigEntity -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel - - -class CalcTokenMixin: - - def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int: - """ - Got the rest tokens available for the model after excluding messages tokens and completion max tokens - - :param model_config: - :param messages: - :return: - """ - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - - max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - max_tokens = (model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template)) or 0 - - if model_context_tokens is None: - return 0 - - if max_tokens is None: - max_tokens = 0 - - prompt_tokens = model_type_instance.get_num_tokens( - model_config.model, - model_config.credentials, - messages - ) - - rest_tokens = model_context_tokens - max_tokens - prompt_tokens - - return rest_tokens - - -class ExceededLLMTokensLimitError(Exception): - pass diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py deleted file mode 100644 index 1f2d5f24b3..0000000000 --- a/api/core/agent/agent/openai_function_call.py +++ /dev/null @@ -1,361 +0,0 @@ -from collections.abc import Sequence -from typing import Any, Optional, Union - -from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent -from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken -from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import ( - AgentAction, - AgentFinish, - AIMessage, - BaseMessage, - HumanMessage, - SystemMessage, - get_buffer_string, -) -from langchain.tools import BaseTool -from pydantic import root_validator - -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from 
core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.third_party.langchain.llms.fake import FakeLLM - - -class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): - moving_summary_buffer: str = "" - moving_summary_index: int = 0 - summary_model_config: ModelConfigEntity = None - model_config: ModelConfigEntity - agent_llm_callback: Optional[AgentLLMCallback] = None - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @root_validator - def validate_llm(cls, values: dict) -> dict: - return values - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, - system_message: Optional[SystemMessage] = SystemMessage( - content="You are a helpful AI assistant." - ), - agent_llm_callback: Optional[AgentLLMCallback] = None, - **kwargs: Any, - ) -> BaseSingleActionAgent: - prompt = cls.create_prompt( - extra_prompt_messages=extra_prompt_messages, - system_message=system_message, - ) - return cls( - model_config=model_config, - llm=FakeLLM(response=''), - prompt=prompt, - tools=tools, - callback_manager=callback_manager, - agent_llm_callback=agent_llm_callback, - **kwargs, - ) - - def should_use_agent(self, query: str): - """ - return should use agent - - :param query: - :return: - """ - original_max_tokens = 0 - for parameter_rule in self.model_config.model_schema.parameter_rules: - if (parameter_rule.name == 'max_tokens' - or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')): - original_max_tokens = (self.model_config.parameters.get(parameter_rule.name) - or self.model_config.parameters.get(parameter_rule.use_template)) or 0 - - self.model_config.parameters['max_tokens'] = 40 - - prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[]) - messages = prompt.to_messages() - - try: - prompt_messages = lc_messages_to_prompt_messages(messages) - model_instance = ModelInstance( - provider_model_bundle=self.model_config.provider_model_bundle, - model=self.model_config.model, - ) - - tools = [] - for function in self.functions: - tool = PromptMessageTool( - **function - ) - - tools.append(tool) - - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=tools, - stream=False, - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - except Exception as e: - raise e - - self.model_config.parameters['max_tokens'] = original_max_tokens - - return True if result.message.tool_calls else False - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, along with observations - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. 
- """ - agent_scratchpad = _format_intermediate_steps(intermediate_steps) - selected_inputs = { - k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad" - } - full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad) - prompt = self.prompt.format_prompt(**full_inputs) - messages = prompt.to_messages() - - prompt_messages = lc_messages_to_prompt_messages(messages) - - # summarize messages if rest_tokens < 0 - try: - prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions) - except ExceededLLMTokensLimitError as e: - return AgentFinish(return_values={"output": str(e)}, log=str(e)) - - model_instance = ModelInstance( - provider_model_bundle=self.model_config.provider_model_bundle, - model=self.model_config.model, - ) - - tools = [] - for function in self.functions: - tool = PromptMessageTool( - **function - ) - - tools.append(tool) - - result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - tools=tools, - stream=False, - callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [], - model_parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - - ai_message = AIMessage( - content=result.message.content or "", - additional_kwargs={ - 'function_call': { - 'id': result.message.tool_calls[0].id, - **result.message.tool_calls[0].function.dict() - } if result.message.tool_calls else None - } - ) - agent_decision = _parse_ai_message(ai_message) - - if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - - return agent_decision - - @classmethod - def get_system_message(cls): - return SystemMessage(content="You are a helpful AI assistant.\n" - "The current date or current time you know is wrong.\n" - "Respond directly if appropriate.") - - def return_stopped_response( - self, - early_stopping_method: str, - intermediate_steps: list[tuple[AgentAction, str]], - **kwargs: Any, - ) -> AgentFinish: - try: - return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs) - except ValueError: - return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") - - def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]: - # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 - rest_tokens = self.get_message_rest_tokens( - self.model_config, - messages, - **kwargs - ) - - rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens - if rest_tokens >= 0: - return messages - - system_message = None - human_message = None - should_summary_messages = [] - for message in messages: - if isinstance(message, SystemMessage): - system_message = message - elif isinstance(message, HumanMessage): - human_message = message - else: - should_summary_messages.append(message) - - if len(should_summary_messages) > 2: - ai_message = should_summary_messages[-2] - function_message = should_summary_messages[-1] - should_summary_messages = should_summary_messages[self.moving_summary_index:-2] - self.moving_summary_index = len(should_summary_messages) - else: - error_msg = "Exceeded LLM tokens limit, stopped." 
- raise ExceededLLMTokensLimitError(error_msg) - - new_messages = [system_message, human_message] - - if self.moving_summary_index == 0: - should_summary_messages.insert(0, human_message) - - self.moving_summary_buffer = self.predict_new_summary( - messages=should_summary_messages, - existing_summary=self.moving_summary_buffer - ) - - new_messages.append(AIMessage(content=self.moving_summary_buffer)) - new_messages.append(ai_message) - new_messages.append(function_message) - - return new_messages - - def predict_new_summary( - self, messages: list[BaseMessage], existing_summary: str - ) -> str: - new_lines = get_buffer_string( - messages, - human_prefix="Human", - ai_prefix="AI", - ) - - chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) - return chain.predict(summary=existing_summary, new_lines=new_lines) - - def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - if model_config.provider == 'azure_openai': - model = model_config.model - model = model.replace("gpt-35", "gpt-3.5") - else: - model = model_config.credentials.get("base_model_name") - - tiktoken_ = _import_tiktoken() - try: - encoding = tiktoken_.encoding_for_model(model) - except KeyError: - model = "cl100k_base" - encoding = tiktoken_.get_encoding(model) - - if model.startswith("gpt-3.5-turbo"): - # every message follows {role/name}\n{content}\n - tokens_per_message = 4 - # if there's a name, the role is omitted - tokens_per_name = -1 - elif model.startswith("gpt-4"): - tokens_per_message = 3 - tokens_per_name = 1 - else: - raise NotImplementedError( - f"get_num_tokens_from_messages() is not presently implemented " - f"for model {model}." - "See https://github.com/openai/openai-python/blob/main/chatml.md for " - "information on how messages are converted to tokens." 
- ) - num_tokens = 0 - for m in messages: - message = _convert_message_to_dict(m) - num_tokens += tokens_per_message - for key, value in message.items(): - if key == "function_call": - for f_key, f_value in value.items(): - num_tokens += len(encoding.encode(f_key)) - num_tokens += len(encoding.encode(f_value)) - else: - num_tokens += len(encoding.encode(value)) - - if key == "name": - num_tokens += tokens_per_name - # every reply is primed with assistant - num_tokens += 3 - - if kwargs.get('functions'): - for function in kwargs.get('functions'): - num_tokens += len(encoding.encode('name')) - num_tokens += len(encoding.encode(function.get("name"))) - num_tokens += len(encoding.encode('description')) - num_tokens += len(encoding.encode(function.get("description"))) - parameters = function.get("parameters") - num_tokens += len(encoding.encode('parameters')) - if 'title' in parameters: - num_tokens += len(encoding.encode('title')) - num_tokens += len(encoding.encode(parameters.get("title"))) - num_tokens += len(encoding.encode('type')) - num_tokens += len(encoding.encode(parameters.get("type"))) - if 'properties' in parameters: - num_tokens += len(encoding.encode('properties')) - for key, value in parameters.get('properties').items(): - num_tokens += len(encoding.encode(key)) - for field_key, field_value in value.items(): - num_tokens += len(encoding.encode(field_key)) - if field_key == 'enum': - for enum_field in field_value: - num_tokens += 3 - num_tokens += len(encoding.encode(enum_field)) - else: - num_tokens += len(encoding.encode(field_key)) - num_tokens += len(encoding.encode(str(field_value))) - if 'required' in parameters: - num_tokens += len(encoding.encode('required')) - for required_field in parameters['required']: - num_tokens += 3 - num_tokens += len(encoding.encode(required_field)) - - return num_tokens diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py deleted file mode 100644 index e1be624204..0000000000 --- a/api/core/agent/agent/structured_chat.py +++ /dev/null @@ -1,306 +0,0 @@ -import re -from collections.abc import Sequence -from typing import Any, Optional, Union, cast - -from langchain import BasePromptTemplate, PromptTemplate -from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent -from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE -from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate -from langchain.schema import ( - AgentAction, - AgentFinish, - AIMessage, - BaseMessage, - HumanMessage, - OutputParserException, - get_buffer_string, -) -from langchain.tools import BaseTool - -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages - -FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). -The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
-Valid "action" values: "Final Answer" or {tool_names} - -Provide only ONE action per $JSON_BLOB, as shown: - -``` -{{{{ - "action": $TOOL_NAME, - "action_input": $INPUT -}}}} -``` - -Follow this format: - -Question: input question to answer -Thought: consider previous and subsequent steps -Action: -``` -$JSON_BLOB -``` -Observation: action result -... (repeat Thought/Action/Observation N times) -Thought: I know what to respond -Action: -``` -{{{{ - "action": "Final Answer", - "action_input": "Final response to human" -}}}} -```""" - - -class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): - moving_summary_buffer: str = "" - moving_summary_index: int = 0 - summary_model_config: ModelConfigEntity = None - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - def should_use_agent(self, query: str): - """ - return should use agent - Using the ReACT mode to determine whether an agent is needed is costly, - so it's better to just use an Agent for reasoning, which is cheaper. - - :param query: - :return: - """ - return True - - def plan( - self, - intermediate_steps: list[tuple[AgentAction, str]], - callbacks: Callbacks = None, - **kwargs: Any, - ) -> Union[AgentAction, AgentFinish]: - """Given input, decided what to do. - - Args: - intermediate_steps: Steps the LLM has taken to date, - along with observatons - callbacks: Callbacks to run. - **kwargs: User inputs. - - Returns: - Action specifying what tool to use. - """ - full_inputs = self.get_full_inputs(intermediate_steps, **kwargs) - prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)]) - - messages = [] - if prompts: - messages = prompts[0].to_messages() - - prompt_messages = lc_messages_to_prompt_messages(messages) - - rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages) - if rest_tokens < 0: - full_inputs = self.summarize_messages(intermediate_steps, **kwargs) - - try: - full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs) - except Exception as e: - raise e - - try: - agent_decision = self.output_parser.parse(full_output) - if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset': - tool_inputs = agent_decision.tool_input - if isinstance(tool_inputs, dict) and 'query' in tool_inputs: - tool_inputs['query'] = kwargs['input'] - agent_decision.tool_input = tool_inputs - return agent_decision - except OutputParserException: - return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " - "I don't know how to respond to that."}, "") - - def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs): - if len(intermediate_steps) >= 2 and self.summary_model_config: - should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] - should_summary_messages = [AIMessage(content=observation) - for _, observation in should_summary_intermediate_steps] - if self.moving_summary_index == 0: - should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input"))) - - self.moving_summary_index = len(intermediate_steps) - else: - error_msg = "Exceeded LLM tokens limit, stopped." 
- raise ExceededLLMTokensLimitError(error_msg) - - if self.moving_summary_buffer and 'chat_history' in kwargs: - kwargs["chat_history"].pop() - - self.moving_summary_buffer = self.predict_new_summary( - messages=should_summary_messages, - existing_summary=self.moving_summary_buffer - ) - - if 'chat_history' in kwargs: - kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer)) - - return self.get_full_inputs([intermediate_steps[-1]], **kwargs) - - def predict_new_summary( - self, messages: list[BaseMessage], existing_summary: str - ) -> str: - new_lines = get_buffer_string( - messages, - human_prefix="Human", - ai_prefix="AI", - ) - - chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) - return chain.predict(summary=existing_summary, new_lines=new_lines) - - @classmethod - def create_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - ) -> BasePromptTemplate: - tool_strings = [] - for tool in tools: - args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args))) - tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}") - formatted_tools = "\n".join(tool_strings) - tool_names = ", ".join([('"' + tool.name + '"') for tool in tools]) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - _memory_prompts = memory_prompts or [] - messages = [ - SystemMessagePromptTemplate.from_template(template), - *_memory_prompts, - HumanMessagePromptTemplate.from_template(human_message_template), - ] - return ChatPromptTemplate(input_variables=input_variables, messages=messages) - - @classmethod - def create_completion_prompt( - cls, - tools: Sequence[BaseTool], - prefix: str = PREFIX, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - ) -> PromptTemplate: - """Create prompt in the style of the zero shot agent. - - Args: - tools: List of tools the agent will have access to, used to format the - prompt. - prefix: String to put before the list of tools. - input_variables: List of input variables the final prompt will expect. - - Returns: - A PromptTemplate with the template assembled from the pieces here. - """ - suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:. 
-Question: {input} -Thought: {agent_scratchpad} -""" - - tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools]) - tool_names = ", ".join([tool.name for tool in tools]) - format_instructions = format_instructions.format(tool_names=tool_names) - template = "\n\n".join([prefix, tool_strings, format_instructions, suffix]) - if input_variables is None: - input_variables = ["input", "agent_scratchpad"] - return PromptTemplate(template=template, input_variables=input_variables) - - def _construct_scratchpad( - self, intermediate_steps: list[tuple[AgentAction, str]] - ) -> str: - agent_scratchpad = "" - for action, observation in intermediate_steps: - agent_scratchpad += action.log - agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}" - - if not isinstance(agent_scratchpad, str): - raise ValueError("agent_scratchpad should be of type string.") - if agent_scratchpad: - llm_chain = cast(LLMChain, self.llm_chain) - if llm_chain.model_config.mode == "chat": - return ( - f"This was your previous work " - f"(but I haven't seen any of it! I only see what " - f"you return as final answer):\n{agent_scratchpad}" - ) - else: - return agent_scratchpad - else: - return agent_scratchpad - - @classmethod - def from_llm_and_tools( - cls, - model_config: ModelConfigEntity, - tools: Sequence[BaseTool], - callback_manager: Optional[BaseCallbackManager] = None, - output_parser: Optional[AgentOutputParser] = None, - prefix: str = PREFIX, - suffix: str = SUFFIX, - human_message_template: str = HUMAN_MESSAGE_TEMPLATE, - format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[list[str]] = None, - memory_prompts: Optional[list[BasePromptTemplate]] = None, - agent_llm_callback: Optional[AgentLLMCallback] = None, - **kwargs: Any, - ) -> Agent: - """Construct an agent from an LLM and tools.""" - cls._validate_tools(tools) - if model_config.mode == "chat": - prompt = cls.create_prompt( - tools, - prefix=prefix, - suffix=suffix, - human_message_template=human_message_template, - format_instructions=format_instructions, - input_variables=input_variables, - memory_prompts=memory_prompts, - ) - else: - prompt = cls.create_completion_prompt( - tools, - prefix=prefix, - format_instructions=format_instructions, - input_variables=input_variables, - ) - llm_chain = LLMChain( - model_config=model_config, - prompt=prompt, - callback_manager=callback_manager, - agent_llm_callback=agent_llm_callback, - parameters={ - 'temperature': 0.2, - 'top_p': 0.3, - 'max_tokens': 1500 - } - ) - tool_names = [tool.name for tool in tools] - _output_parser = output_parser - return cls( - llm_chain=llm_chain, - allowed_tools=tool_names, - output_parser=_output_parser, - **kwargs, - ) diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index a4845d0ff1..d9a3447bda 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -1,4 +1,3 @@ -import json import logging from typing import cast @@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.moderation.base import ModerationException from core.tools.entities.tool_entities import ToolRuntimeVariablePool from extensions.ext_database import db -from models.model import App, Conversation, Message, MessageAgentThought, MessageChain +from models.model import App, Conversation, Message, MessageAgentThought from models.tools import ToolConversationVariables logger = 
logging.getLogger(__name__) @@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner): # convert db variables to tool variables tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables) - - message_chain = self._init_message_chain( - message=message, - query=query - ) # init model instance model_instance = ModelInstance( @@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner): 'pool': db_variables.variables }) - def _init_message_chain(self, message: Message, query: str) -> MessageChain: - """ - Init MessageChain - :param message: message - :param query: query - :return: - """ - message_chain = MessageChain( - message_id=message.id, - type="AgentExecutor", - input=json.dumps({ - "input": query - }) - ) - - db.session.add(message_chain) - db.session.commit() - - return message_chain - - def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None: - """ - Save MessageChain - :param message_chain: message chain - :param output_text: output text - :return: - """ - message_chain.output = json.dumps({ - "output": output_text - }) - db.session.commit() - def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity, message: Message) -> LLMUsage: """ diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index e1972efb51..99df249ddf 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity -from core.features.dataset_retrieval import DatasetRetrievalFeature +from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationException diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 20e4bc7992..5fd635bc3b 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -175,7 +175,7 @@ class GenerateTaskPipeline: 'id': self._message.id, 'message_id': self._message.id, 'mode': self._conversation.mode, - 'answer': event.llm_result.message.content, + 'answer': self._task_state.llm_result.message.content, 'metadata': {}, 'created_at': int(self._message.created_at.timestamp()) } diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index a86afd817a..7498a07559 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -3,12 +3,12 @@ import logging from typing import Optional, cast import numpy as np -from langchain.embeddings.base import Embeddings from sqlalchemy.exc import IntegrityError from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.rag.datasource.entity.embedding import Embeddings from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/entities/agent_entities.py 
b/api/core/entities/agent_entities.py new file mode 100644 index 0000000000..0cdf8670c4 --- /dev/null +++ b/api/core/entities/agent_entities.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class PlanningStrategy(Enum): + ROUTER = 'router' + REACT_ROUTER = 'react_router' + REACT = 'react' + FUNCTION_CALL = 'function_call' diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py deleted file mode 100644 index 7412d81281..0000000000 --- a/api/core/features/agent_runner.py +++ /dev/null @@ -1,199 +0,0 @@ -import logging -from typing import Optional, cast - -from langchain.tools import BaseTool - -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy -from core.application_queue_manager import ApplicationQueueManager -from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.entities.application_entities import ( - AgentEntity, - AppOrchestrationConfigEntity, - InvokeFrom, - ModelConfigEntity, -) -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers import model_provider_factory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool -from extensions.ext_database import db -from models.dataset import Dataset -from models.model import Message - -logger = logging.getLogger(__name__) - - -class AgentRunnerFeature: - def __init__(self, tenant_id: str, - app_orchestration_config: AppOrchestrationConfigEntity, - model_config: ModelConfigEntity, - config: AgentEntity, - queue_manager: ApplicationQueueManager, - message: Message, - user_id: str, - agent_llm_callback: AgentLLMCallback, - callback: AgentLoopGatherCallbackHandler, - memory: Optional[TokenBufferMemory] = None,) -> None: - """ - Agent runner - :param tenant_id: tenant id - :param app_orchestration_config: app orchestration config - :param model_config: model config - :param config: dataset config - :param queue_manager: queue manager - :param message: message - :param user_id: user id - :param agent_llm_callback: agent llm callback - :param callback: callback - :param memory: memory - """ - self.tenant_id = tenant_id - self.app_orchestration_config = app_orchestration_config - self.model_config = model_config - self.config = config - self.queue_manager = queue_manager - self.message = message - self.user_id = user_id - self.agent_llm_callback = agent_llm_callback - self.callback = callback - self.memory = memory - - def run(self, query: str, - invoke_from: InvokeFrom) -> Optional[str]: - """ - Retrieve agent loop result. 
- :param query: query - :param invoke_from: invoke from - :return: - """ - provider = self.config.provider - model = self.config.model - tool_configs = self.config.tools - - # check model is support tool calling - provider_instance = model_provider_factory.get_provider_instance(provider=provider) - model_type_instance = provider_instance.get_model_instance(ModelType.LLM) - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - # get model schema - model_schema = model_type_instance.get_model_schema( - model=model, - credentials=self.model_config.credentials - ) - - if not model_schema: - return None - - planning_strategy = PlanningStrategy.REACT - features = model_schema.features - if features: - if ModelFeature.TOOL_CALL in features \ - or ModelFeature.MULTI_TOOL_CALL in features: - planning_strategy = PlanningStrategy.FUNCTION_CALL - - tools = self.to_tools( - tool_configs=tool_configs, - invoke_from=invoke_from, - callbacks=[self.callback, DifyStdOutCallbackHandler()], - ) - - if len(tools) == 0: - return None - - agent_configuration = AgentConfiguration( - strategy=planning_strategy, - model_config=self.model_config, - tools=tools, - memory=self.memory, - max_iterations=10, - max_execution_time=400.0, - early_stopping_method="generate", - agent_llm_callback=self.agent_llm_callback, - callbacks=[self.callback, DifyStdOutCallbackHandler()] - ) - - agent_executor = AgentExecutor(agent_configuration) - - try: - # check if should use agent - should_use_agent = agent_executor.should_use_agent(query) - if not should_use_agent: - return None - - result = agent_executor.run(query) - return result.output - except Exception as ex: - logger.exception("agent_executor run failed") - return None - - def to_dataset_retriever_tool(self, tool_config: dict, - invoke_from: InvokeFrom) \ - -> Optional[BaseTool]: - """ - A dataset tool is a tool that can be used to retrieve information from a dataset - :param tool_config: tool config - :param invoke_from: invoke from - """ - show_retrieve_source = self.app_orchestration_config.show_retrieve_source - - hit_callback = DatasetIndexToolCallbackHandler( - queue_manager=self.queue_manager, - app_id=self.message.app_id, - message_id=self.message.id, - user_id=self.user_id, - invoke_from=invoke_from - ) - - # get dataset from dataset id - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == self.tenant_id, - Dataset.id == tool_config.get("id") - ).first() - - # pass if dataset is not available - if not dataset: - return None - - # pass if dataset is not available - if (dataset and dataset.available_document_count == 0 - and dataset.available_document_count == 0): - return None - - # get retrieval model config - default_retrieval_model = { - 'search_method': 'semantic_search', - 'reranking_enable': False, - 'reranking_model': { - 'reranking_provider_name': '', - 'reranking_model_name': '' - }, - 'top_k': 2, - 'score_threshold_enabled': False - } - - retrieval_model_config = dataset.retrieval_model \ - if dataset.retrieval_model else default_retrieval_model - - # get top k - top_k = retrieval_model_config['top_k'] - - # get score threshold - score_threshold = None - score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled") - if score_threshold_enabled: - score_threshold = retrieval_model_config.get("score_threshold") - - tool = DatasetRetrieverTool.from_dataset( - dataset=dataset, - top_k=top_k, - score_threshold=score_threshold, - hit_callbacks=[hit_callback], - return_resource=show_retrieve_source, - 
retriever_from=invoke_from.to_source() - ) - - return tool \ No newline at end of file diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py index e1b64cf73f..fd516e465f 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/features/annotation_reply.py @@ -59,7 +59,7 @@ class AnnotationReplyFeature: documents = vector.search_by_vector( query=query, - k=1, + top_k=1, score_threshold=score_threshold, filter={ 'group_id': [dataset.id] diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index c4a5767b04..2a4ae7e135 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner): for message in messages: result.append(UserPromptMessage(content=message.query)) agent_thoughts: list[MessageAgentThought] = message.agent_thoughts - for agent_thought in agent_thoughts: - tools = agent_thought.tool - if tools: - tools = tools.split(';') - tool_calls: list[AssistantPromptMessage.ToolCall] = [] - tool_call_response: list[ToolPromptMessage] = [] - tool_inputs = json.loads(agent_thought.tool_input) - for tool in tools: - # generate a uuid for tool call - tool_call_id = str(uuid.uuid4()) - tool_calls.append(AssistantPromptMessage.ToolCall( - id=tool_call_id, - type='function', - function=AssistantPromptMessage.ToolCall.ToolCallFunction( + if agent_thoughts: + for agent_thought in agent_thoughts: + tools = agent_thought.tool + if tools: + tools = tools.split(';') + tool_calls: list[AssistantPromptMessage.ToolCall] = [] + tool_call_response: list[ToolPromptMessage] = [] + tool_inputs = json.loads(agent_thought.tool_input) + for tool in tools: + # generate a uuid for tool call + tool_call_id = str(uuid.uuid4()) + tool_calls.append(AssistantPromptMessage.ToolCall( + id=tool_call_id, + type='function', + function=AssistantPromptMessage.ToolCall.ToolCallFunction( + name=tool, + arguments=json.dumps(tool_inputs.get(tool, {})), + ) + )) + tool_call_response.append(ToolPromptMessage( + content=agent_thought.observation, name=tool, - arguments=json.dumps(tool_inputs.get(tool, {})), - ) - )) - tool_call_response.append(ToolPromptMessage( - content=agent_thought.observation, - name=tool, - tool_call_id=tool_call_id, - )) + tool_call_id=tool_call_id, + )) - result.extend([ - AssistantPromptMessage( - content=agent_thought.thought, - tool_calls=tool_calls, - ), - *tool_call_response - ]) + result.extend([ + AssistantPromptMessage( + content=agent_thought.thought, + tool_calls=tool_calls, + ), + *tool_call_response + ]) + if not tools: + result.append(AssistantPromptMessage(content=agent_thought.thought)) + else: + if message.answer: + result.append(AssistantPromptMessage(content=message.answer)) return result \ No newline at end of file diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index aa4a6797cd..809834c8cb 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): thought='', action_str='', observation='', - action=None + action=None, ) # publish agent thought if it's first iteration @@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): thought=message.content, action_str='', action=None, - observation=None + observation=None, ) if message.tool_calls: try: @@ -484,7 +484,7 @@ class 
AssistantCotApplicationRunner(BaseAssistantApplicationRunner): elif isinstance(message, ToolPromptMessage): if current_scratchpad: current_scratchpad.observation = message.content - + return agent_scratchpad def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], @@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): prompt_message.content = system_message overridden = True break + + # convert tool prompt messages to user prompt messages + for idx, prompt_message in enumerate(prompt_messages): + if isinstance(prompt_message, ToolPromptMessage): + prompt_messages[idx] = UserPromptMessage( + content=prompt_message.content + ) if not overridden: prompt_messages.insert(0, SystemPromptMessage( diff --git a/api/core/third_party/langchain/llms/__init__.py b/api/core/features/dataset_retrieval/__init__.py similarity index 100% rename from api/core/third_party/langchain/llms/__init__.py rename to api/core/features/dataset_retrieval/__init__.py diff --git a/api/core/third_party/spark/__init__.py b/api/core/features/dataset_retrieval/agent/__init__.py similarity index 100% rename from api/core/third_party/spark/__init__.py rename to api/core/features/dataset_retrieval/agent/__init__.py diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/features/dataset_retrieval/agent/agent_llm_callback.py similarity index 100% rename from api/core/agent/agent/agent_llm_callback.py rename to api/core/features/dataset_retrieval/agent/agent_llm_callback.py diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/features/dataset_retrieval/agent/fake_llm.py similarity index 100% rename from api/core/third_party/langchain/llms/fake.py rename to api/core/features/dataset_retrieval/agent/fake_llm.py diff --git a/api/core/chain/llm_chain.py b/api/core/features/dataset_retrieval/agent/llm_chain.py similarity index 91% rename from api/core/chain/llm_chain.py rename to api/core/features/dataset_retrieval/agent/llm_chain.py index 86fb156292..e5155e15a0 100644 --- a/api/core/chain/llm_chain.py +++ b/api/core/features/dataset_retrieval/agent/llm_chain.py @@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun from langchain.schema import Generation, LLMResult from langchain.schema.language_model import BaseLanguageModel -from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages +from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.features.dataset_retrieval.agent.fake_llm import FakeLLM from core.model_manager import ModelInstance -from core.third_party.langchain.llms.fake import FakeLLM class LLMChain(LCLLMChain): diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py similarity index 98% rename from api/core/agent/agent/multi_dataset_router_agent.py rename to api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py index eb594c3d21..59923202fd 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/features/dataset_retrieval/agent/multi_dataset_router_agent.py @@ -12,9 +12,9 @@ from pydantic import root_validator from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages +from core.features.dataset_retrieval.agent.fake_llm import FakeLLM 
from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessageTool -from core.third_party.langchain.llms.fake import FakeLLM class MultiDatasetRouterAgent(OpenAIFunctionsAgent): diff --git a/api/core/data_loader/file_extractor.py b/api/core/features/dataset_retrieval/agent/output_parser/__init__.py similarity index 100% rename from api/core/data_loader/file_extractor.py rename to api/core/features/dataset_retrieval/agent/output_parser/__init__.py diff --git a/api/core/agent/agent/output_parser/structured_chat.py b/api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py similarity index 100% rename from api/core/agent/agent/output_parser/structured_chat.py rename to api/core/features/dataset_retrieval/agent/output_parser/structured_chat.py diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py similarity index 99% rename from api/core/agent/agent/structed_multi_dataset_router_agent.py rename to api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py index e104bb01f9..e69302bfd6 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/features/dataset_retrieval/agent/structed_multi_dataset_router_agent.py @@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool -from core.chain.llm_chain import LLMChain from core.entities.application_entities import ModelConfigEntity +from core.features.dataset_retrieval.agent.llm_chain import LLMChain FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
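For reference, the format instructions above ask the model to answer each turn with a single JSON blob carrying an action key (the tool name) and an action_input key (the tool input). A minimal sketch of a conforming reply; the dataset tool name and query are illustrative and do not appear in this diff:

    import json

    # Illustrative values only; FORMAT_INSTRUCTIONS prescribes just the two keys.
    action_blob = json.dumps({
        "action": "dataset-4f2a",                    # tool name, i.e. which dataset to route to
        "action_input": {"query": "refund policy"},  # input forwarded to that tool
    }, ensure_ascii=False)

    reply = f"Action:```{action_blob}```"  # the runtime then records an Observation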
diff --git a/api/core/agent/agent_executor.py b/api/core/features/dataset_retrieval/agent_based_dataset_executor.py similarity index 69% rename from api/core/agent/agent_executor.py rename to api/core/features/dataset_retrieval/agent_based_dataset_executor.py index 70fe00ee13..588ccc91f5 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/features/dataset_retrieval/agent_based_dataset_executor.py @@ -1,4 +1,3 @@ -import enum import logging from typing import Optional, Union @@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks from langchain.tools import BaseTool from pydantic import BaseModel, Extra -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent -from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent -from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser -from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent -from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent +from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import prompt_messages_to_lc_messages +from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback +from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent +from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser +from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent from core.helper import moderation from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError @@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool -class PlanningStrategy(str, enum.Enum): - ROUTER = 'router' - REACT_ROUTER = 'react_router' - REACT = 'react' - FUNCTION_CALL = 'function_call' - - class AgentConfiguration(BaseModel): strategy: PlanningStrategy model_config: ModelConfigEntity @@ -62,28 +53,7 @@ class AgentExecutor: self.agent = self._init_agent() def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]: - if self.configuration.strategy == PlanningStrategy.REACT: - agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools( - model_config=self.configuration.model_config, - tools=self.configuration.tools, - output_parser=StructuredChatOutputParser(), - summary_model_config=self.configuration.summary_model_config - if self.configuration.summary_model_config else None, - agent_llm_callback=self.configuration.agent_llm_callback, - verbose=True - ) - elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL: - agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools( - model_config=self.configuration.model_config, - tools=self.configuration.tools, - extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages()) - if self.configuration.memory else None, # used for read chat histories memory - summary_model_config=self.configuration.summary_model_config - if self.configuration.summary_model_config else None, - agent_llm_callback=self.configuration.agent_llm_callback, - verbose=True - ) - 
elif self.configuration.strategy == PlanningStrategy.ROUTER: + if self.configuration.strategy == PlanningStrategy.ROUTER: self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool) or isinstance(t, DatasetMultiRetrieverTool)] diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval/dataset_retrieval.py similarity index 97% rename from api/core/features/dataset_retrieval.py rename to api/core/features/dataset_retrieval/dataset_retrieval.py index 488a8ca8d0..3e54d8644d 100644 --- a/api/core/features/dataset_retrieval.py +++ b/api/core/features/dataset_retrieval/dataset_retrieval.py @@ -2,9 +2,10 @@ from typing import Optional, cast from langchain.tools import BaseTool -from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.entities.agent_entities import PlanningStrategy from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity +from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 58b551f295..880a30cdf4 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -104,37 +104,17 @@ class HostingConfiguration: if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"): hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200")) + trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS") trial_quota = TrialHostingQuota( quota_limit=hosted_quota_limit, - restrict_models=[ - RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", model_type=ModelType.LLM), - ] + restrict_models=trial_models ) quotas.append(trial_quota) if app_config.get("HOSTED_OPENAI_PAID_ENABLED"): + paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS") paid_quota = PaidHostingQuota( - restrict_models=[ - RestrictModel(model="gpt-4", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM), - RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM), - 
RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM), - RestrictModel(model="text-davinci-003", model_type=ModelType.LLM), - ] + restrict_models=paid_models ) quotas.append(paid_quota) @@ -258,3 +238,11 @@ class HostingConfiguration: return HostedModerationConfig( enabled=False ) + + @staticmethod + def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]: + models_str = app_config.get(env_var) + models_list = models_str.split(",") if models_str else [] + return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if + model_name.strip()] + diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index c8a2e09443..f5ea49bb5e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -365,8 +365,9 @@ class IndexingRunner: notion_info={ "notion_workspace_id": data_source_info['notion_workspace_id'], "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['notion_page_type'], - "document": dataset_document + "notion_page_type": data_source_info['type'], + "document": dataset_document, + "tenant_id": dataset_document.tenant_id }, document_model=dataset_document.doc_form ) @@ -664,6 +665,7 @@ class IndexingRunner: ) # load index index_processor.load(dataset, chunk_documents) + db.session.add(dataset) document_ids = [document.metadata['doc_id'] for document in chunk_documents] db.session.query(DocumentSegment).filter( diff --git a/api/core/model_runtime/README_CN.md b/api/core/model_runtime/README_CN.md index 6950cdc0c7..3664fa2ca3 100644 --- a/api/core/model_runtime/README_CN.md +++ b/api/core/model_runtime/README_CN.md @@ -20,7 +20,7 @@  - 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./schema.md)。 + 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等,规则设计详见:[Schema](./docs/zh_Hans/schema.md)。 - 可选择的模型列表展示 @@ -86,4 +86,4 @@ Model Runtime 分三层:  ### [接口的具体实现 👈🏻](./docs/zh_Hans/interfaces.md) -你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 \ No newline at end of file +你可以在这里找到你想要查看的接口的具体实现,以及接口的参数和返回值的具体含义。 diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index 856f4ce7d1..776f6802e6 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { 'min': 1, 'max': 2048, 'precision': 0, + }, + DefaultParameterName.RESPONSE_FORMAT: { + 'label': { + 'en_US': 'Response Format', + 'zh_Hans': '回复格式', + }, + 'type': 'string', + 'help': { + 'en_US': 'Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.', + 'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等', + }, + 'required': False, + 'options': ['JSON', 'XML'], } } \ No newline at end of file diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index e35be27f86..52c2d66f9f 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -91,6 +91,7 @@ class DefaultParameterName(Enum): PRESENCE_PENALTY = "presence_penalty" FREQUENCY_PENALTY = "frequency_penalty" MAX_TOKENS = "max_tokens" + RESPONSE_FORMAT = "response_format" @classmethod def value_of(cls, value: Any) -> 'DefaultParameterName': diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py 
b/api/core/model_runtime/model_providers/__base/ai_model.py index a9f7a539e2..026e6eca21 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -262,23 +262,23 @@ class AIModel(ABC): try: default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max: + if not parameter_rule.max and 'max' in default_parameter_rule: parameter_rule.max = default_parameter_rule['max'] - if not parameter_rule.min: + if not parameter_rule.min and 'min' in default_parameter_rule: parameter_rule.min = default_parameter_rule['min'] - if not parameter_rule.precision: + if not parameter_rule.default and 'default' in default_parameter_rule: parameter_rule.default = default_parameter_rule['default'] - if not parameter_rule.precision: + if not parameter_rule.precision and 'precision' in default_parameter_rule: parameter_rule.precision = default_parameter_rule['precision'] - if not parameter_rule.required: + if not parameter_rule.required and 'required' in default_parameter_rule: parameter_rule.required = default_parameter_rule['required'] - if not parameter_rule.help: + if not parameter_rule.help and 'help' in default_parameter_rule: parameter_rule.help = I18nObject( en_US=default_parameter_rule['help']['en_US'], ) - if not parameter_rule.help.en_US: + if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']): parameter_rule.help.en_US = default_parameter_rule['help']['en_US'] - if not parameter_rule.help.zh_Hans: + if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']): parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US']) except ValueError: pass diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 1f7edd245f..4b546a5356 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -9,7 +9,13 @@ from typing import Optional, Union from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ( ModelPropertyKey, ModelType, @@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel): ) try: - result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + if "response_format" in model_parameters: + result = self._code_block_mode_wrapper( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + callbacks=callbacks + ) + else: + result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) 
except Exception as e:
             self._trigger_invoke_error_callbacks(
                 model=model,
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):

         return result

+    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
+                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
+                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
+        """
+        Code block mode wrapper, ensure the response is a code block with output markdown quote
+
+        :param model: model name
+        :param credentials: model credentials
+        :param prompt_messages: prompt messages
+        :param model_parameters: model parameters
+        :param tools: tools for tool calling
+        :param stop: stop words
+        :param stream: is stream response
+        :param user: unique user id
+        :param callbacks: callbacks
+        :return: full response or stream response chunk generator result
+        """
+
+        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+
+        code_block = model_parameters.get("response_format", "")
+        if not code_block:
+            return self._invoke(
+                model=model,
+                credentials=credentials,
+                prompt_messages=prompt_messages,
+                model_parameters=model_parameters,
+                tools=tools,
+                stop=stop,
+                stream=stream,
+                user=user
+            )
+
+        model_parameters.pop("response_format")
+        stop = stop or []
+        stop.extend(["\n```", "```\n"])
+        block_prompts = block_prompts.replace("{{block}}", code_block)
+
+        # check if there is a system message
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
+            # override the system message
+            prompt_messages[0] = SystemPromptMessage(
+                content=block_prompts
+                    .replace("{{instructions}}", prompt_messages[0].content)
+            )
+        else:
+            # insert the system message
+            prompt_messages.insert(0, SystemPromptMessage(
+                content=block_prompts
+                    .replace("{{instructions}}", f"Please output a valid {code_block} object.")
+            ))
+
+        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
+            # add ```JSON\n to the last message
+            prompt_messages[-1].content += f"\n```{code_block}\n"
+        else:
+            # append a user message
+            prompt_messages.append(UserPromptMessage(
+                content=f"```{code_block}\n"
+            ))
+
+        response = self._invoke(
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            model_parameters=model_parameters,
+            tools=tools,
+            stop=stop,
+            stream=stream,
+            user=user
+        )
+
+        if isinstance(response, Generator):
+            first_chunk = next(response)
+            def new_generator():
+                yield first_chunk
+                yield from response
+
+            if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
+                return self._code_block_mode_stream_processor_with_backtick(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    input_generator=new_generator()
+                )
+            else:
+                return self._code_block_mode_stream_processor(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    input_generator=new_generator()
+                )
+
+        return response
+
+    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
+                                          input_generator: Generator[LLMResultChunk, None, None]
+                                          ) -> Generator[LLMResultChunk, None, None]:
+        """
+        Code block mode stream processor, ensure the response is a code block with output markdown quote
+
+        
:param model: model name + :param prompt_messages: prompt messages + :param input_generator: input generator + :return: output generator + """ + state = "normal" + backtick_count = 0 + for piece in input_generator: + if piece.delta.message.content: + content = piece.delta.message.content + piece.delta.message.content = "" + yield piece + piece = content + else: + yield piece + continue + new_piece = "" + for char in piece: + if state == "normal": + if char == "`": + state = "in_backticks" + backtick_count = 1 + else: + new_piece += char + elif state == "in_backticks": + if char == "`": + backtick_count += 1 + if backtick_count == 3: + state = "skip_content" + backtick_count = 0 + else: + new_piece += "`" * backtick_count + char + state = "normal" + backtick_count = 0 + elif state == "skip_content": + if char.isspace(): + state = "normal" + + if new_piece: + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=new_piece, + tool_calls=[] + ), + ) + ) + + def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, + input_generator: Generator[LLMResultChunk, None, None]) \ + -> Generator[LLMResultChunk, None, None]: + """ + Code block mode stream processor, ensure the response is a code block with output markdown quote. + This version skips the language identifier that follows the opening triple backticks. + + :param model: model name + :param prompt_messages: prompt messages + :param input_generator: input generator + :return: output generator + """ + state = "search_start" + backtick_count = 0 + + for piece in input_generator: + if piece.delta.message.content: + content = piece.delta.message.content + # Reset content to ensure we're only processing and yielding the relevant parts + piece.delta.message.content = "" + # Yield a piece with cleared content before processing it to maintain the generator structure + yield piece + piece = content + else: + # Yield pieces without content directly + yield piece + continue + + if state == "done": + continue + + new_piece = "" + for char in piece: + if state == "search_start": + if char == "`": + backtick_count += 1 + if backtick_count == 3: + state = "skip_language" + backtick_count = 0 + else: + backtick_count = 0 + elif state == "skip_language": + # Skip everything until the first newline, marking the end of the language identifier + if char == "\n": + state = "in_code_block" + elif state == "in_code_block": + if char == "`": + backtick_count += 1 + if backtick_count == 3: + state = "done" + break + else: + if backtick_count > 0: + # If backticks were counted but we're still collecting content, it was a false start + new_piece += "`" * backtick_count + backtick_count = 0 + new_piece += char + + elif state == "done": + break + + if new_piece: + # Only yield content collected within the code block + yield LLMResultChunk( + model=model, + prompt_messages=prompt_messages, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=new_piece, + tool_calls=[] + ), + ) + ) + def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, @@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel): :return: full response or stream response chunk generator result """ raise NotImplementedError - + @abstractmethod def get_num_tokens(self, model: str, credentials: dict, 
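# A self-contained sketch (not part of this diff) of the fence-stripping state
# machine implemented by _code_block_mode_stream_processor above: scan the
# stream character by character, drop triple-backtick fences plus whatever
# directly follows them up to the next whitespace, and re-emit the rest.
# Function name and sample chunks are illustrative.
def strip_code_fences(pieces):
    state, backticks, out = "normal", 0, []
    for piece in pieces:
        kept = ""
        for ch in piece:
            if state == "normal":
                if ch == "`":
                    state, backticks = "in_backticks", 1
                else:
                    kept += ch
            elif state == "in_backticks":
                if ch == "`":
                    backticks += 1
                    if backticks == 3:
                        state, backticks = "skip_content", 0
                else:
                    # lone backticks are real content; put them back
                    kept += "`" * backticks + ch
                    state, backticks = "normal", 0
            else:  # skip_content: swallow e.g. the JSON language tag
                if ch.isspace():
                    state = "normal"
        if kept:
            out.append(kept)
    return "".join(out)

# strip_code_fences(['```JSON\n{"ans', 'wer": "hi"}\n```']) == '{"answer": "hi"}\n'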
prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int:
diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml
index b2c6518395..8c878d67d8 100644
--- a/api/core/model_runtime/model_providers/_position.yaml
+++ b/api/core/model_runtime/model_providers/_position.yaml
@@ -6,6 +6,7 @@
 - bedrock
 - togetherai
 - ollama
+- mistralai
 - replicate
 - huggingface_hub
 - zhipuai
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
index 08beef3caa..6707c34594 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.1.yaml
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
index 3c49067630..12faf60bc9 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-2.yaml
@@ -27,6 +27,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '8.00'
   output: '24.00'
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
index d44859faa3..25d32a09af 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-instant-1.yaml
@@ -26,6 +26,8 @@ parameter_rules:
     default: 4096
     min: 1
     max: 4096
+  - name: response_format
+    use_template: response_format
 pricing:
   input: '1.63'
   output: '5.51'
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index c743708896..00e5ef6fda 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
 from anthropic.types import Completion, completion_create_params
 from httpx import Timeout

+from core.model_runtime.callbacks.base_callback import Callback
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

+ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
+The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+ + +{{instructions}} + +""" class AnthropicLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, @@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if 'response_format' in model_parameters and model_parameters['response_format']: + stop = stop or [] + self._transform_json_prompts( + model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format'] + ) + model_parameters.pop('response_format') + + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _transform_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts + """ + if "```\n" not in stop: + stop.append("```\n") + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=ANTHROPIC_BLOCK_MODE_PROMPT + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) + )) + + prompt_messages.append(AssistantPromptMessage( + content=f"```{response_format}\n" + )) def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml index 3b98e615e6..ffdc9c3659 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-pro.yaml @@ -27,6 +27,8 @@ parameter_rules: default: 2048 min: 1 max: 2048 + - name: response_format + use_template: response_format pricing: input: '0.00' output: '0.00' diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 686761ab5f..2feff8ebe9 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large logger = logging.getLogger(__name__) +GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. 
+The structure of the {{block}} object can be found in the instructions, use {"answer": "$your_answer"} as the default structure
+if you are not sure about the structure.
+
+
+{{instructions}}
+
+"""
+
+
 class GoogleLargeLanguageModel(LargeLanguageModel):

     def _invoke(self, model: str, credentials: dict,
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
         """
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
-
+
     def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                        tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
diff --git a/api/core/model_runtime/model_providers/mistralai/__init__.py b/api/core/model_runtime/model_providers/mistralai/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/mistralai/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/mistralai/_assets/icon_l_en.png
new file mode 100644
index 0000000000..f019b1edce
Binary files /dev/null and b/api/core/model_runtime/model_providers/mistralai/_assets/icon_l_en.png differ
diff --git a/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png
new file mode 100644
index 0000000000..de199b4317
Binary files /dev/null and b/api/core/model_runtime/model_providers/mistralai/_assets/icon_s_en.png differ
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml
new file mode 100644
index 0000000000..5e74dc5dfe
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/_position.yaml
@@ -0,0 +1,5 @@
+- open-mistral-7b
+- open-mixtral-8x7b
+- mistral-small-latest
+- mistral-medium-latest
+- mistral-large-latest
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/llm.py b/api/core/model_runtime/model_providers/mistralai/llm/llm.py
new file mode 100644
index 0000000000..01ed8010de
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/llm.py
@@ -0,0 +1,31 @@
+from collections.abc import Generator
+from typing import Optional, Union
+
+from core.model_runtime.entities.llm_entities import LLMResult
+from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
+from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
+
+
+class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
+                stream: bool = True, user: Optional[str] = None) \
+            -> Union[LLMResult, Generator]:
+
+        self._add_custom_parameters(credentials)
+
+        # mistral does not support user/stop arguments
+        stop = []
+        user = None
+
+        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        self._add_custom_parameters(credentials)
+        super().validate_credentials(model, credentials)
+
+    @staticmethod
+    def _add_custom_parameters(credentials: dict) -> None:
+        credentials['mode'] = 'chat'
+        credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
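Before delegating to the OpenAI-compatible base class, _add_custom_parameters above pins the chat mode and endpoint onto the credentials, so a caller only supplies an API key. A quick sketch of the effect; the key value is a placeholder:

    from core.model_runtime.model_providers.mistralai.llm.llm import MistralAILargeLanguageModel

    credentials = {'api_key': 'sk-placeholder'}  # placeholder value
    MistralAILargeLanguageModel._add_custom_parameters(credentials)
    assert credentials['mode'] == 'chat'
    assert credentials['endpoint_url'] == 'https://api.mistral.ai/v1'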
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
new file mode 100644
index 0000000000..b729012c40
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml
@@ -0,0 +1,50 @@
+model: mistral-large-latest
+label:
+  zh_Hans: mistral-large-latest
+  en_US: mistral-large-latest
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.7
+    min: 0
+    max: 1
+  - name: top_p
+    use_template: top_p
+    default: 1
+    min: 0
+    max: 1
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 8000
+  - name: safe_prompt
+    default: false
+    type: boolean
+    help:
+      en_US: Whether to inject a safety prompt before all conversations.
+      zh_Hans: 是否开启提示词审查
+    label:
+      en_US: SafePrompt
+      zh_Hans: 提示词审查
+  - name: random_seed
+    type: int
+    help:
+      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+    label:
+      en_US: RandomSeed
+      zh_Hans: 随机数种子
+    default: 0
+    min: 0
+    max: 2147483647
+pricing:
+  input: '0.008'
+  output: '0.024'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
new file mode 100644
index 0000000000..6e586b4843
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-medium-latest.yaml
@@ -0,0 +1,50 @@
+model: mistral-medium-latest
+label:
+  zh_Hans: mistral-medium-latest
+  en_US: mistral-medium-latest
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.7
+    min: 0
+    max: 1
+  - name: top_p
+    use_template: top_p
+    default: 1
+    min: 0
+    max: 1
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 8000
+  - name: safe_prompt
+    default: false
+    type: boolean
+    help:
+      en_US: Whether to inject a safety prompt before all conversations.
+      zh_Hans: 是否开启提示词审查
+    label:
+      en_US: SafePrompt
+      zh_Hans: 提示词审查
+  - name: random_seed
+    type: int
+    help:
+      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+    label:
+      en_US: RandomSeed
+      zh_Hans: 随机数种子
+    default: 0
+    min: 0
+    max: 2147483647
+pricing:
+  input: '0.0027'
+  output: '0.0081'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
new file mode 100644
index 0000000000..4e7e6147f5
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/mistral-small-latest.yaml
@@ -0,0 +1,50 @@
+model: mistral-small-latest
+label:
+  zh_Hans: mistral-small-latest
+  en_US: mistral-small-latest
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.7
+    min: 0
+    max: 1
+  - name: top_p
+    use_template: top_p
+    default: 1
+    min: 0
+    max: 1
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 8000
+  - name: safe_prompt
+    default: false
+    type: boolean
+    help:
+      en_US: Whether to inject a safety prompt before all conversations.
+      zh_Hans: 是否开启提示词审查
+    label:
+      en_US: SafePrompt
+      zh_Hans: 提示词审查
+  - name: random_seed
+    type: int
+    help:
+      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+    label:
+      en_US: RandomSeed
+      zh_Hans: 随机数种子
+    default: 0
+    min: 0
+    max: 2147483647
+pricing:
+  input: '0.002'
+  output: '0.006'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
new file mode 100644
index 0000000000..30454f7df2
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mistral-7b.yaml
@@ -0,0 +1,50 @@
+model: open-mistral-7b
+label:
+  zh_Hans: open-mistral-7b
+  en_US: open-mistral-7b
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  context_size: 8000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.7
+    min: 0
+    max: 1
+  - name: top_p
+    use_template: top_p
+    default: 1
+    min: 0
+    max: 1
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 2048
+  - name: safe_prompt
+    default: false
+    type: boolean
+    help:
+      en_US: Whether to inject a safety prompt before all conversations.
+      zh_Hans: 是否开启提示词审查
+    label:
+      en_US: SafePrompt
+      zh_Hans: 提示词审查
+  - name: random_seed
+    type: int
+    help:
+      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
+    label:
+      en_US: RandomSeed
+      zh_Hans: 随机数种子
+    default: 0
+    min: 0
+    max: 2147483647
+pricing:
+  input: '0.00025'
+  output: '0.00025'
+  unit: '0.001'
+  currency: USD
diff --git a/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
new file mode 100644
index 0000000000..a35cf0a9ae
--- /dev/null
+++ b/api/core/model_runtime/model_providers/mistralai/llm/open-mixtral-8x7b.yaml
@@ -0,0 +1,50 @@
+model: open-mixtral-8x7b
+label:
+  zh_Hans: open-mixtral-8x7b
+  en_US: open-mixtral-8x7b
+model_type: llm
+features:
+  - agent-thought
+model_properties:
+  context_size: 32000
+parameter_rules:
+  - name: temperature
+    use_template: temperature
+    default: 0.7
+    min: 0
+    max: 1
+  - name: top_p
+    use_template: top_p
+    default: 1
+    min: 0
+    max: 1
+  - name: max_tokens
+    use_template: max_tokens
+    default: 1024
+    min: 1
+    max: 8000
+  - name: safe_prompt
+    default: false
+    type: boolean
+    help:
+      en_US: Whether to inject a safety prompt before all conversations.
+      zh_Hans: 是否开启提示词审查
+    label:
+      en_US: SafePrompt
+      zh_Hans: 提示词审查
+  - name: random_seed
+    type: int
+    help:
+      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
+ zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定 + label: + en_US: RandomSeed + zh_Hans: 随机数种子 + default: 0 + min: 0 + max: 2147483647 +pricing: + input: '0.0007' + output: '0.0007' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.py b/api/core/model_runtime/model_providers/mistralai/mistralai.py new file mode 100644 index 0000000000..f1d825f6c6 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.py @@ -0,0 +1,30 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class MistralAIProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + model_instance.validate_credentials( + model='open-mistral-7b', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/mistralai/mistralai.yaml b/api/core/model_runtime/model_providers/mistralai/mistralai.yaml new file mode 100644 index 0000000000..c9b4226ea6 --- /dev/null +++ b/api/core/model_runtime/model_providers/mistralai/mistralai.yaml @@ -0,0 +1,31 @@ +provider: mistralai +label: + en_US: MistralAI +description: + en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest. 
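The `validate_provider_credentials` implementation in the `mistralai.py` hunk above verifies a key by issuing a real call against `open-mistral-7b`, the cheapest model in the family. A minimal smoke-test sketch, assuming the class can be constructed directly (in Dify it is normally driven through the provider service rather than called like this):

```python
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.mistralai.mistralai import MistralAIProvider

provider = MistralAIProvider()
try:
    # Probes open-mistral-7b with the supplied key; raises on failure.
    provider.validate_provider_credentials({"api_key": "<MISTRAL_API_KEY>"})
    print("credentials OK")
except CredentialsValidateFailedError as e:
    print(f"invalid MistralAI credentials: {e}")
```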
+ zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。 +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +background: "#FFFFFF" +help: + title: + en_US: Get your API Key from MistralAI + zh_Hans: 从 MistralAI 获取 API Key + url: + en_US: https://console.mistral.ai/api-keys/ +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 5db3e2827b..05feee877e 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -13,6 +13,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) + user = user[:32] if user else None return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml index 3e40db01f9..c1602b2efc 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0125.yaml @@ -24,6 +24,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.0005' output: '0.0015' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml index 6d519cbee6..31dc53e89f 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-0613.yaml @@ -24,6 +24,8 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + use_template: response_format pricing: input: '0.0015' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml index 499792e39d..56ab965c39 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-1106.yaml @@ -24,6 +24,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.001' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml index a86bacb34f..4a0e2ef191 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml +++ 
b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k-0613.yaml @@ -24,6 +24,8 @@ parameter_rules: default: 512 min: 1 max: 16385 + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.004' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml index 467041e842..3684c1945c 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-16k.yaml @@ -24,6 +24,8 @@ parameter_rules: default: 512 min: 1 max: 16385 + - name: response_format + use_template: response_format pricing: input: '0.003' output: '0.004' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml index 926ee05d97..ad831539e0 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo-instruct.yaml @@ -21,6 +21,8 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + use_template: response_format pricing: input: '0.0015' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml index fddf1836c4..4ffd31a814 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-3.5-turbo.yaml @@ -24,6 +24,18 @@ parameter_rules: default: 512 min: 1 max: 4096 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: response_format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object pricing: input: '0.001' output: '0.002' diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 2a1137d443..2ea65780f1 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI logger = logging.getLogger(__name__) +OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure.
+ + +{{instructions}} + +""" class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): """ @@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): user=user ) + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + # handle fine tune remote models + base_model = model + if model.startswith('ft:'): + base_model = model.split(':')[1] + + # get model mode + model_mode = self.get_model_mode(base_model, credentials) + + # transform response format + if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + stop = stop or [] + if model_mode == LLMMode.CHAT: + # chat model + self._transform_chat_json_prompts( + model=base_model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters['response_format'] + ) + else: + self._transform_completion_json_prompts( + model=base_model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user, + response_format=model_parameters['response_format'] + ) + model_parameters.pop('response_format') + + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + def _transform_chat_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts + """ + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) + ) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n")) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) + )) + prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}")) + + def _transform_completion_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts + """ + if "```\n" not in stop: + stop.append("```\n") + if "\n```" not in stop: + stop.append("\n```") + + # override the last user message + user_message = None + for i in 
range(len(prompt_messages) - 1, -1, -1): + if isinstance(prompt_messages[i], UserPromptMessage): + user_message = prompt_messages[i] + break + + if user_message: + if prompt_messages[i].content[-11:] == 'Assistant: ': + # now we are in the chat app, remove the last assistant message + prompt_messages[i].content = prompt_messages[i].content[:-11] + prompt_messages[i] = UserPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", user_message.content) + .replace("{{block}}", response_format) + ) + prompt_messages[i].content += f"Assistant:\n```{response_format}\n" + else: + prompt_messages[i] = UserPromptMessage( + content=OPENAI_BLOCK_MODE_PROMPT + .replace("{{instructions}}", user_message.content) + .replace("{{block}}", response_format) + ) + prompt_messages[i].content += f"\n```{response_format}\n" + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 7ae8b87764..405f93498e 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -13,6 +13,7 @@ from dashscope.common.error import ( ) from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ # invoke model return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user) + + def _code_block_mode_wrapper(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \ + -> LLMResult | Generator: + """ + Wrapper for code block mode + """ + block_prompts = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure.
+ + +{{instructions}} + +""" + + code_block = model_parameters.get("response_format", "") + if not code_block: + return self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + model_parameters.pop("response_format") + stop = stop or [] + stop.extend(["\n```", "```\n"]) + block_prompts = block_prompts.replace("{{block}}", code_block) + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=block_prompts + .replace("{{instructions}}", prompt_messages[0].content) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=block_prompts + .replace("{{instructions}}", f"Please output a valid {code_block} object.") + )) + + mode = self.get_model_mode(model, credentials) + if mode == LLMMode.CHAT: + if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # add ```JSON\n to the last message + prompt_messages[-1].content += f"\n```{code_block}\n" + else: + # append a user message + prompt_messages.append(UserPromptMessage( + content=f"```{code_block}\n" + )) + else: + prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n")) + + response = self._invoke( + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=stream, + user=user + ) + + if isinstance(response, Generator): + return self._code_block_mode_stream_processor_with_backtick( + model=model, + prompt_messages=prompt_messages, + input_generator=response + ) + + return response def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: @@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs['stop'] = stop # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) @@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel): params = { 'model': model, **model_parameters, - **credentials_kwargs + **credentials_kwargs, + **extra_model_kwargs, } mode = self.get_model_mode(model, credentials) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml index 11eca82736..3461863e67 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-1201.yaml @@ -57,3 +57,5 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. 
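To make the shared JSON block-mode behavior concrete, here is a stand-alone re-creation (an illustration, not Dify code) of the chat-model transform implemented in the OpenAI and Tongyi hunks above, using plain dicts in place of Dify's `PromptMessage` classes and a simplified wrapper template:

```python
# Simplified re-creation of the chat JSON block-mode transform: wrap the system
# message in the block prompt, pre-fill an opening ```JSON fence as the
# assistant turn, and stop generation at the closing fence.
BLOCK_PROMPT = ("You should always follow the instructions and output a valid "
                "{block} object.\n<instructions>\n{instructions}\n</instructions>")

def to_block_mode(messages: list[dict], stop: list[str], block: str = "JSON"):
    for seq in ("```\n", "\n```"):
        if seq not in stop:
            stop.append(seq)
    if messages and messages[0]["role"] == "system":
        messages[0]["content"] = BLOCK_PROMPT.format(
            block=block, instructions=messages[0]["content"])
    else:
        messages.insert(0, {"role": "system", "content": BLOCK_PROMPT.format(
            block=block, instructions=f"Please output a valid {block} object.")})
    # Pre-filled assistant turn: the model continues inside the fence.
    messages.append({"role": "assistant", "content": f"\n```{block}\n"})
    return messages, stop

msgs, stop = to_block_mode([{"role": "user", "content": "What is 2 + 2?"}], [])
```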
required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml index 58aab20004..9089c5904a 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max-longcontext.yaml @@ -57,3 +57,5 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml index ccfa2356c3..eb1e8ac09b 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-max.yaml @@ -57,3 +57,5 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml index 1dd13a1a26..83640371f9 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-plus.yaml @@ -56,6 +56,8 @@ parameter_rules: help: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. + - name: response_format + use_template: response_format pricing: input: '0.02' output: '0.02' diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml index 8da184ec9e..5455555bbd 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-turbo.yaml @@ -57,6 +57,8 @@ parameter_rules: zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。 en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment. 
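A quick trace (illustrative values only) of how the new `response_format` rule travels through the wrappers above: the parameter arrives in `model_parameters`, triggers the prompt transform, and is popped before the underlying provider API is called, so the provider never sees the key:

```python
# Names mirror the diff; the data is made up for illustration.
model_parameters = {"temperature": 0.7, "response_format": "JSON"}
stop: list[str] = []

if model_parameters.get("response_format") in ("JSON", "XML"):
    fmt = model_parameters.pop("response_format")  # removed before the API call
    stop.extend(["```\n", "\n```"])                # halt at the closing fence
    # ...prompt_messages are rewritten to block mode for `fmt` here...

assert "response_format" not in model_parameters  # provider payload stays clean
```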
required: false + - name: response_format + use_template: response_format pricing: input: '0.008' output: '0.008' diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml index 0439506817..de9249ea34 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-4.yaml @@ -25,6 +25,8 @@ parameter_rules: use_template: presence_penalty - name: frequency_penalty use_template: frequency_penalty + - name: response_format + use_template: response_format - name: disable_search label: zh_Hans: 禁用搜索 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml index fe06eb9975..b709644628 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-8k.yaml @@ -25,6 +25,8 @@ parameter_rules: use_template: presence_penalty - name: frequency_penalty use_template: frequency_penalty + - name: response_format + use_template: response_format - name: disable_search label: zh_Hans: 禁用搜索 diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml index bcd9d1235b..2769c214e0 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot-turbo.yaml @@ -25,3 +25,5 @@ parameter_rules: use_template: presence_penalty - name: frequency_penalty use_template: frequency_penalty + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml index 75fb3b1942..5b1237b243 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie-bot.yaml @@ -34,3 +34,5 @@ parameter_rules: zh_Hans: 禁用模型自行进行外部搜索。 en_US: Disable the model to perform external search. required: false + - name: response_format + use_template: response_format diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 51b3c97497..d39d63deee 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -1,6 +1,7 @@ from collections.abc import Generator -from typing import cast +from typing import Optional, Union, cast +from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, @@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( RateLimitReachedError, ) +ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. +The structure of the {{block}} object can be found in the instructions; use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. -class ErnieBotLarguageModel(LargeLanguageModel): + +{{instructions}} + + +You should also complete the text that starts with ``` but do not output ``` directly.
+""" + +class ErnieBotLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, @@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel): return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, + callbacks: list[Callback] = None) -> Union[LLMResult, Generator]: + """ + Code block mode wrapper for invoking large language model + """ + if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']: + response_format = model_parameters['response_format'] + stop = stop or [] + self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format) + model_parameters.pop('response_format') + if stream: + return self._code_block_mode_stream_processor( + model=model, + prompt_messages=prompt_messages, + input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages, + model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user) + ) + + return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def _transform_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ + -> None: + """ + Transform json prompts to model prompts + """ + + # check if there is a system message + if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # override the system message + prompt_messages[0] = SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT + .replace("{{instructions}}", prompt_messages[0].content) + .replace("{{block}}", response_format) + ) + else: + # insert the system message + prompt_messages.insert(0, SystemPromptMessage( + content=ERNIE_BOT_BLOCK_MODE_PROMPT + .replace("{{instructions}}", f"Please output a valid {response_format} object.") + .replace("{{block}}", response_format) + )) + + if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # add ```JSON\n to the last message + prompt_messages[-1].content += "\n```JSON\n{\n" + else: + # append a user message + prompt_messages.append(UserPromptMessage( + content="```JSON\n{\n" + )) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool] | None = None) -> int: # tools is not supported yet diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index c62422dfb0..27277164c9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp from 
core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk from core.model_runtime.utils import helper +GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. +The structure of the JSON object can be found in the instructions; use {"answer": "$your_answer"} as the default structure +if you are not sure about the structure. + +And you should always end the block with a "```" to indicate the end of the JSON object. + + +{{instructions}} + + +```JSON""" class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): @@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # invoke model + # stop = stop or [] + # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # stream: bool = True, user: str | None = None) \ + # -> None: + # """ + # Transform json prompts to model prompts + # """ + # if "}\n\n" not in stop: + # stop.append("}\n\n") + # # check if there is a system message + # if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): + # # override the system message + # prompt_messages[0] = SystemPromptMessage( + # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content) + # ) + # else: + # # insert the system message + # prompt_messages.insert(0, SystemPromptMessage( + # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.") + # )) + # # check if the last message is a user message + # if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): + # # add ```JSON\n to the last message + # prompt_messages[-1].content += "\n```JSON\n" + # else: + # # append a user message + # prompt_messages.append(UserPromptMessage( + # content="```JSON\n" + # )) + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ @@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): """ extra_model_kwargs = {} if stop: - extra_model_kwargs['stop_sequences'] = stop + extra_model_kwargs['stop'] = stop client = ZhipuAI( api_key=credentials_kwargs['api_key'] @@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): ] if stream: - response = client.chat.completions.create(stream=stream, **params) + response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs) return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages) - response = client.chat.completions.create(**params) + response = client.chat.completions.create(**params, **extra_model_kwargs) return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages) def _handle_generate_response(self, model: str, diff --git a/api/core/rag/datasource/keyword/keyword_factory.py b/api/core/rag/datasource/keyword/keyword_factory.py index bccec20714..f5e2bf0f83 100644 ---
a/api/core/rag/datasource/keyword/keyword_factory.py +++ b/api/core/rag/datasource/keyword/keyword_factory.py @@ -1,4 +1,4 @@ -from typing import Any, cast +from typing import Any from flask import current_app @@ -14,7 +14,7 @@ class Keyword: self._keyword_processor = self._init_keyword() def _init_keyword(self) -> BaseKeyword: - config = cast(dict, current_app.config) + config = current_app.config keyword_type = config.get('KEYWORD_STORE') if not keyword_type: @@ -25,7 +25,7 @@ class Keyword: dataset=self._dataset ) else: - raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.") + raise ValueError(f"Keyword store {keyword_type} is not supported.") def create(self, texts: list[Document], **kwargs): self._keyword_processor.create(texts, **kwargs) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 79673ffa83..0f9c753056 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -2,7 +2,6 @@ import threading from typing import Optional from flask import Flask, current_app -from flask_login import current_user from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.keyword_factory import Keyword @@ -27,6 +26,11 @@ class RetrievalService: @classmethod def retrieve(cls, retrival_method: str, dataset_id: str, query: str, top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None): + dataset = db.session.query(Dataset).filter( + Dataset.id == dataset_id + ).first() + if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0: + return [] all_documents = [] threads = [] # retrieval_model source with keyword @@ -35,7 +39,8 @@ class RetrievalService: 'flask_app': current_app._get_current_object(), 'dataset_id': dataset_id, 'query': query, - 'top_k': top_k + 'top_k': top_k, + 'all_documents': all_documents }) threads.append(keyword_thread) keyword_thread.start() @@ -73,7 +78,7 @@ class RetrievalService: thread.join() if retrival_method == 'hybrid_search': - data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False) + data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False) all_documents = data_post_processor.invoke( query=query, documents=all_documents, @@ -96,7 +101,7 @@ class RetrievalService: documents = keyword.search( query, - k=top_k + top_k=top_k ) all_documents.extend(documents) @@ -116,7 +121,7 @@ class RetrievalService: documents = vector.search_by_vector( query, search_type='similarity_score_threshold', - k=top_k, + top_k=top_k, score_threshold=score_threshold, filter={ 'group_id': [dataset.id] diff --git a/api/core/rag/datasource/vdb/field.py b/api/core/rag/datasource/vdb/field.py index 6a594a83ca..dc400dafbb 100644 --- a/api/core/rag/datasource/vdb/field.py +++ b/api/core/rag/datasource/vdb/field.py @@ -7,4 +7,4 @@ class Field(Enum): GROUP_KEY = "group_id" VECTOR = "vector" TEXT_KEY = "text" - PRIMARY_KEY = " id" + PRIMARY_KEY = "id" diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 9a251ede97..0fc8ed5a26 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -124,12 +124,23 @@ class MilvusVector(BaseVector): def delete_by_ids(self, doc_ids: list[str]) -> None: - 
self._client.delete(collection_name=self._collection_name, pks=doc_ids) + result = self._client.query(collection_name=self._collection_name, + filter=f'metadata["doc_id"] in {doc_ids}', + output_fields=["id"]) + if result: + ids = [item["id"] for item in result] + self._client.delete(collection_name=self._collection_name, pks=ids) def delete(self) -> None: + alias = uuid4().hex + if self._client_config.secure: + uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port) + else: + uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port) + connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password) from pymilvus import utility - utility.drop_collection(self._collection_name, None) + utility.drop_collection(self._collection_name, None, using=alias) def text_exists(self, id: str) -> bool: diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index dd8fc93041..1921de07ed 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -1,4 +1,5 @@ -from typing import Any, cast +import json +from typing import Any from flask import current_app @@ -22,7 +23,7 @@ class Vector: self._vector_processor = self._init_vector() def _init_vector(self) -> BaseVector: - config = cast(dict, current_app.config) + config = current_app.config vector_type = config.get('VECTOR_STORE') if self._dataset.index_struct_dict: @@ -39,6 +40,11 @@ class Vector: else: dataset_id = self._dataset.id collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + index_struct_dict = { + "type": 'weaviate', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) return WeaviateVector( collection_name=collection_name, config=WeaviateConfig( @@ -66,6 +72,13 @@ class Vector: dataset_id = self._dataset.id collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + if not self._dataset.index_struct_dict: + index_struct_dict = { + "type": 'qdrant', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) + return QdrantVector( collection_name=collection_name, group_id=self._dataset.id, @@ -84,6 +97,11 @@ class Vector: else: dataset_id = self._dataset.id collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node' + index_struct_dict = { + "type": 'milvus', + "vector_store": {"class_prefix": collection_name} + } + self._dataset.index_struct = json.dumps(index_struct_dict) return MilvusVector( collection_name=collection_name, config=MilvusConfig( diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 5c3a810fbf..008e54085d 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -127,7 +127,10 @@ class WeaviateVector(BaseVector): ) def delete(self): - self._client.schema.delete_class(self._collection_name) + # check whether the index already exists + schema = self._default_schema(self._collection_name) + if self._client.schema.contains(schema): + self._client.schema.delete_class(self._collection_name) def text_exists(self, id: str) -> bool: collection_name = self._collection_name @@ -147,10 +150,11 @@ class WeaviateVector(BaseVector): return True def delete_by_ids(self, ids: list[str]) -> None: - 
self._client.data_object.delete( - ids, - class_name=self._collection_name - ) + for uuid in ids: + self._client.data_object.delete( + class_name=self._collection_name, + uuid=uuid, + ) def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: """Look up similar documents by embedding vector in Weaviate.""" diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index bc5310f7be..49cd4d0c03 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -12,6 +12,7 @@ class NotionInfo(BaseModel): notion_obj_id: str notion_page_type: str document: Document = None + tenant_id: str class Config: arbitrary_types_allowed = True diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 7c7dc5bdae..0de7065335 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -132,7 +132,8 @@ class ExtractProcessor: notion_workspace_id=extract_setting.notion_info.notion_workspace_id, notion_obj_id=extract_setting.notion_info.notion_obj_id, notion_page_type=extract_setting.notion_info.notion_page_type, - document_model=extract_setting.notion_info.document + document_model=extract_setting.notion_info.document, + tenant_id=extract_setting.notion_info.tenant_id, ) return extractor.extract() else: diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index 557ea42b19..ceb5306255 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -1,13 +1,14 @@ """Abstract interface for document loader implementations.""" -from typing import Optional +from bs4 import BeautifulSoup from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.extractor.helpers import detect_file_encodings from core.rag.models.document import Document class HtmlExtractor(BaseExtractor): - """Load html files. + + """ + Load html files. 
Args: @@ -15,57 +16,19 @@ class HtmlExtractor(BaseExtractor): """ def __init__( - self, - file_path: str, - encoding: Optional[str] = None, - autodetect_encoding: bool = False, - source_column: Optional[str] = None, - csv_args: Optional[dict] = None, + self, + file_path: str ): """Initialize with file path.""" self._file_path = file_path - self._encoding = encoding - self._autodetect_encoding = autodetect_encoding - self.source_column = source_column - self.csv_args = csv_args or {} def extract(self) -> list[Document]: - """Load data into document objects.""" - try: - with open(self._file_path, newline="", encoding=self._encoding) as csvfile: - docs = self._read_from_file(csvfile) - except UnicodeDecodeError as e: - if self._autodetect_encoding: - detected_encodings = detect_file_encodings(self._file_path) - for encoding in detected_encodings: - try: - with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile: - docs = self._read_from_file(csvfile) - break - except UnicodeDecodeError: - continue - else: - raise RuntimeError(f"Error loading {self._file_path}") from e + return [Document(page_content=self._load_as_text())] - return docs + def _load_as_text(self) -> str: + with open(self._file_path, "rb") as fp: + soup = BeautifulSoup(fp, 'html.parser') + text = soup.get_text() + text = text.strip() if text else '' - def _read_from_file(self, csvfile) -> list[Document]: - docs = [] - csv_reader = csv.DictReader(csvfile, **self.csv_args) # type: ignore - for i, row in enumerate(csv_reader): - content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items()) - try: - source = ( - row[self.source_column] - if self.source_column is not None - else '' - ) - except KeyError: - raise ValueError( - f"Source column '{self.source_column}' not found in CSV file." 
- ) - metadata = {"source": source, "row": i} - doc = Document(page_content=content, metadata=metadata) - docs.append(doc) - - return docs + return text \ No newline at end of file diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index f28436ffd9..c40064fd1d 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -4,7 +4,6 @@ from typing import Any, Optional import requests from flask import current_app -from flask_login import current_user from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document @@ -30,8 +29,10 @@ class NotionExtractor(BaseExtractor): notion_workspace_id: str, notion_obj_id: str, notion_page_type: str, + tenant_id: str, document_model: Optional[DocumentModel] = None, - notion_access_token: Optional[str] = None + notion_access_token: Optional[str] = None, + ): self._notion_access_token = None self._document_model = document_model @@ -41,7 +42,7 @@ class NotionExtractor(BaseExtractor): if notion_access_token: self._notion_access_token = notion_access_token else: - self._notion_access_token = self._get_access_token(current_user.current_tenant_id, + self._notion_access_token = self._get_access_token(tenant_id, self._notion_workspace_id) if not self._notion_access_token: integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN') diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py deleted file mode 100644 index 5c97bba530..0000000000 --- a/api/core/third_party/spark/spark_llm.py +++ /dev/null @@ -1,189 +0,0 @@ -import base64 -import hashlib -import hmac -import json -import queue -import ssl -from datetime import datetime -from time import mktime -from typing import Optional -from urllib.parse import urlencode, urlparse -from wsgiref.handlers import format_date_time - -import websocket - - -class SparkLLMClient: - def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None): - domain = 'spark-api.xf-yun.com' - endpoint = 'chat' - if api_domain: - domain = api_domain - if model_name == 'spark-v3': - endpoint = 'multimodal' - - model_api_configs = { - 'spark': { - 'version': 'v1.1', - 'chat_domain': 'general' - }, - 'spark-v2': { - 'version': 'v2.1', - 'chat_domain': 'generalv2' - }, - 'spark-v3': { - 'version': 'v3.1', - 'chat_domain': 'generalv3' - }, - 'spark-v3.5': { - 'version': 'v3.5', - 'chat_domain': 'generalv3.5' - } - } - - api_version = model_api_configs[model_name]['version'] - - self.chat_domain = model_api_configs[model_name]['chat_domain'] - self.api_base = f"wss://{domain}/{api_version}/{endpoint}" - self.app_id = app_id - self.ws_url = self.create_url( - urlparse(self.api_base).netloc, - urlparse(self.api_base).path, - self.api_base, - api_key, - api_secret - ) - - self.queue = queue.Queue() - self.blocking_message = '' - - def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: - # generate timestamp by RFC1123 - now = datetime.now() - date = format_date_time(mktime(now.timetuple())) - - signature_origin = "host: " + host + "\n" - signature_origin += "date: " + date + "\n" - signature_origin += "GET " + path + " HTTP/1.1" - - # encrypt using hmac-sha256 - signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), - digestmod=hashlib.sha256).digest() - - signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') - - 
authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' - - authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') - - v = { - "authorization": authorization, - "date": date, - "host": host - } - # generate url - url = api_base + '?' + urlencode(v) - return url - - def run(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None, streaming: bool = False): - websocket.enableTrace(False) - ws = websocket.WebSocketApp( - self.ws_url, - on_message=self.on_message, - on_error=self.on_error, - on_close=self.on_close, - on_open=self.on_open - ) - ws.messages = messages - ws.user_id = user_id - ws.model_kwargs = model_kwargs - ws.streaming = streaming - ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) - - def on_error(self, ws, error): - self.queue.put({ - 'status_code': error.status_code, - 'error': error.resp_body.decode('utf-8') - }) - ws.close() - - def on_close(self, ws, close_status_code, close_reason): - self.queue.put({'done': True}) - - def on_open(self, ws): - self.blocking_message = '' - data = json.dumps(self.gen_params( - messages=ws.messages, - user_id=ws.user_id, - model_kwargs=ws.model_kwargs - )) - ws.send(data) - - def on_message(self, ws, message): - data = json.loads(message) - code = data['header']['code'] - if code != 0: - self.queue.put({ - 'status_code': 400, - 'error': f"Code: {code}, Error: {data['header']['message']}" - }) - ws.close() - else: - choices = data["payload"]["choices"] - status = choices["status"] - content = choices["text"][0]["content"] - if ws.streaming: - self.queue.put({'data': content}) - else: - self.blocking_message += content - - if status == 2: - if not ws.streaming: - self.queue.put({'data': self.blocking_message}) - ws.close() - - def gen_params(self, messages: list, user_id: str, - model_kwargs: Optional[dict] = None) -> dict: - data = { - "header": { - "app_id": self.app_id, - "uid": user_id - }, - "parameter": { - "chat": { - "domain": self.chat_domain - } - }, - "payload": { - "message": { - "text": messages - } - } - } - - if model_kwargs: - data['parameter']['chat'].update(model_kwargs) - - return data - - def subscribe(self): - while True: - content = self.queue.get() - if 'error' in content: - if content['status_code'] == 401: - raise SparkError('[Spark] The credentials you provided are incorrect. ' - 'Please double-check and fill them in again.') - elif content['status_code'] == 403: - raise SparkError("[Spark] Sorry, the credentials you provided are access denied. 
" - "Please try again after obtaining the necessary permissions.") - else: - raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}") - - if 'data' not in content: - break - yield content - - -class SparkError(Exception): - pass diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py deleted file mode 100644 index 208490a5bf..0000000000 --- a/api/core/tool/current_datetime_tool.py +++ /dev/null @@ -1,24 +0,0 @@ -from datetime import datetime - -from langchain.tools import BaseTool -from pydantic import BaseModel, Field - - -class DatetimeToolInput(BaseModel): - type: str = Field(..., description="Type for current time, must be: datetime.") - - -class DatetimeTool(BaseTool): - """Tool for querying current datetime.""" - name: str = "current_datetime" - args_schema: type[BaseModel] = DatetimeToolInput - description: str = "A tool when you want to get the current date, time, week, month or year, " \ - "and the time zone is UTC. Result is \" \"." - - def _run(self, type: str) -> str: - # get current time - current_time = datetime.utcnow() - return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A") - - async def _arun(self, tool_input: str) -> str: - raise NotImplementedError() diff --git a/api/core/tool/provider/base.py b/api/core/tool/provider/base.py deleted file mode 100644 index bf5dc3bf56..0000000000 --- a/api/core/tool/provider/base.py +++ /dev/null @@ -1,63 +0,0 @@ -import base64 -from abc import ABC, abstractmethod -from typing import Optional - -from extensions.ext_database import db -from libs import rsa -from models.account import Tenant -from models.tool import ToolProvider, ToolProviderName - - -class BaseToolProvider(ABC): - def __init__(self, tenant_id: str): - self.tenant_id = tenant_id - - @abstractmethod - def get_provider_name(self) -> ToolProviderName: - raise NotImplementedError - - @abstractmethod - def encrypt_credentials(self, credentials: dict) -> Optional[dict]: - raise NotImplementedError - - @abstractmethod - def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: - raise NotImplementedError - - @abstractmethod - def credentials_to_func_kwargs(self) -> Optional[dict]: - raise NotImplementedError - - @abstractmethod - def credentials_validate(self, credentials: dict): - raise NotImplementedError - - def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]: - """ - Returns the Provider instance for the given tenant_id and tool_name. 
- """ - query = db.session.query(ToolProvider).filter( - ToolProvider.tenant_id == self.tenant_id, - ToolProvider.tool_name == self.get_provider_name().value - ) - - if must_enabled: - query = query.filter(ToolProvider.is_enabled == True) - - return query.first() - - def encrypt_token(self, token) -> str: - tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() - encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) - return base64.b64encode(encrypted_token).decode() - - def decrypt_token(self, token: str, obfuscated: bool = False) -> str: - token = rsa.decrypt(base64.b64decode(token), self.tenant_id) - - if obfuscated: - return self._obfuscated_token(token) - - return token - - def _obfuscated_token(self, token: str) -> str: - return token[:6] + '*' * (len(token) - 8) + token[-2:] diff --git a/api/core/tool/provider/errors.py b/api/core/tool/provider/errors.py deleted file mode 100644 index c6b937063f..0000000000 --- a/api/core/tool/provider/errors.py +++ /dev/null @@ -1,2 +0,0 @@ -class ToolValidateFailedError(Exception): - description = "Tool Provider Validate failed" diff --git a/api/core/tool/provider/serpapi_provider.py b/api/core/tool/provider/serpapi_provider.py deleted file mode 100644 index c87510e541..0000000000 --- a/api/core/tool/provider/serpapi_provider.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import Optional - -from core.tool.provider.base import BaseToolProvider -from core.tool.provider.errors import ToolValidateFailedError -from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper -from models.tool import ToolProviderName - - -class SerpAPIToolProvider(BaseToolProvider): - def get_provider_name(self) -> ToolProviderName: - """ - Returns the name of the provider. - - :return: - """ - return ToolProviderName.SERPAPI - - def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: - """ - Returns the credentials for SerpAPI as a dictionary. - - :param obfuscated: obfuscate credentials if True - :return: - """ - tool_provider = self.get_provider(must_enabled=True) - if not tool_provider: - return None - - credentials = tool_provider.credentials - if not credentials: - return None - - if credentials.get('api_key'): - credentials['api_key'] = self.decrypt_token(credentials.get('api_key'), obfuscated) - - return credentials - - def credentials_to_func_kwargs(self) -> Optional[dict]: - """ - Returns the credentials function kwargs as a dictionary. - - :return: - """ - credentials = self.get_credentials() - if not credentials: - return None - - return { - 'serpapi_api_key': credentials.get('api_key') - } - - def credentials_validate(self, credentials: dict): - """ - Validates the given credentials. - - :param credentials: - :return: - """ - if 'api_key' not in credentials or not credentials.get('api_key'): - raise ToolValidateFailedError("SerpAPI api_key is required.") - - api_key = credentials.get('api_key') - - try: - OptimizedSerpAPIWrapper(serpapi_api_key=api_key).run(query='test') - except Exception as e: - raise ToolValidateFailedError("SerpAPI api_key is invalid. {}".format(e)) - - def encrypt_credentials(self, credentials: dict) -> Optional[dict]: - """ - Encrypts the given credentials. 
- - :param credentials: - :return: - """ - credentials['api_key'] = self.encrypt_token(credentials.get('api_key')) - return credentials diff --git a/api/core/tool/provider/tool_provider_service.py b/api/core/tool/provider/tool_provider_service.py deleted file mode 100644 index b3602da010..0000000000 --- a/api/core/tool/provider/tool_provider_service.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import Optional - -from core.tool.provider.base import BaseToolProvider -from core.tool.provider.serpapi_provider import SerpAPIToolProvider - - -class ToolProviderService: - - def __init__(self, tenant_id: str, provider_name: str): - self.provider = self._init_provider(tenant_id, provider_name) - - def _init_provider(self, tenant_id: str, provider_name: str) -> BaseToolProvider: - if provider_name == 'serpapi': - return SerpAPIToolProvider(tenant_id) - else: - raise Exception('tool provider {} not found'.format(provider_name)) - - def get_credentials(self, obfuscated: bool = False) -> Optional[dict]: - """ - Returns the credentials for Tool as a dictionary. - - :param obfuscated: - :return: - """ - return self.provider.get_credentials(obfuscated) - - def credentials_validate(self, credentials: dict): - """ - Validates the given credentials. - - :param credentials: - :raises: ValidateFailedError - """ - return self.provider.credentials_validate(credentials) - - def encrypt_credentials(self, credentials: dict): - """ - Encrypts the given credentials. - - :param credentials: - :return: - """ - return self.provider.encrypt_credentials(credentials) diff --git a/api/core/tool/serpapi_wrapper.py b/api/core/tool/serpapi_wrapper.py deleted file mode 100644 index 0c3f107d94..0000000000 --- a/api/core/tool/serpapi_wrapper.py +++ /dev/null @@ -1,51 +0,0 @@ -from langchain import SerpAPIWrapper -from pydantic import BaseModel, Field - - -class OptimizedSerpAPIInput(BaseModel): - query: str = Field(..., description="search query.") - - -class OptimizedSerpAPIWrapper(SerpAPIWrapper): - - @staticmethod - def _process_response(res: dict, num_results: int = 5) -> str: - """Process response from SerpAPI.""" - if "error" in res.keys(): - raise ValueError(f"Got error from SerpAPI: {res['error']}") - if "answer_box" in res.keys() and type(res["answer_box"]) == list: - res["answer_box"] = res["answer_box"][0] - if "answer_box" in res.keys() and "answer" in res["answer_box"].keys(): - toret = res["answer_box"]["answer"] - elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): - toret = res["answer_box"]["snippet"] - elif ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): - toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): - toret = res["sports_results"]["game_spotlight"] - elif ( - "shopping_results" in res.keys() - and "title" in res["shopping_results"][0].keys() - ): - toret = res["shopping_results"][:3] - elif ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): - toret = res["knowledge_graph"]["description"] - elif 'organic_results' in res.keys() and len(res['organic_results']) > 0: - toret = "" - for result in res["organic_results"][:num_results]: - if "link" in result: - toret += "----------------\nlink: " + result["link"] + "\n" - if "snippet" in result: - toret += "snippet: " + result["snippet"] + "\n" - else: - toret = "No good search result found" - return "search result:\n" + toret diff --git 
a/api/core/tool/web_reader_tool.py b/api/core/tool/web_reader_tool.py deleted file mode 100644 index 6a3e52a7b4..0000000000 --- a/api/core/tool/web_reader_tool.py +++ /dev/null @@ -1,443 +0,0 @@ -import hashlib -import json -import os -import re -import site -import subprocess -import tempfile -import unicodedata -from contextlib import contextmanager -from typing import Any - -import requests -from bs4 import BeautifulSoup, CData, Comment, NavigableString -from langchain.chains import RefineDocumentsChain -from langchain.chains.summarize import refine_prompts -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.tools.base import BaseTool -from newspaper import Article -from pydantic import BaseModel, Field -from regex import regex - -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity -from core.rag.extractor import extract_processor -from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.models.document import Document - -FULL_TEMPLATE = """ -TITLE: {title} -AUTHORS: {authors} -PUBLISH DATE: {publish_date} -TOP_IMAGE_URL: {top_image} -TEXT: - -{text} -""" - - -class WebReaderToolInput(BaseModel): - url: str = Field(..., description="URL of the website to read") - summary: bool = Field( - default=False, - description="When the user's question requires extracting the summarizing content of the webpage, " - "set it to true." - ) - cursor: int = Field( - default=0, - description="Start reading from this character." - "Use when the first response was truncated" - "and you want to continue reading the page." - "The value cannot exceed 24000.", - ) - - -class WebReaderTool(BaseTool): - """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" - - name: str = "web_reader" - args_schema: type[BaseModel] = WebReaderToolInput - description: str = "use this to read a website. " \ - "If you can answer the question based on the information provided, " \ - "there is no need to use." - page_contents: str = None - url: str = None - max_chunk_length: int = 4000 - summary_chunk_tokens: int = 4000 - summary_chunk_overlap: int = 0 - summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] - continue_reading: bool = True - model_config: ModelConfigEntity - model_parameters: dict[str, Any] - - def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: - try: - if not self.page_contents or self.url != url: - page_contents = get_url(url) - self.page_contents = page_contents - self.url = url - else: - page_contents = self.page_contents - except Exception as e: - return f'Read this website failed, caused by: {str(e)}.' - - if summary: - character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=self.summary_chunk_tokens, - chunk_overlap=self.summary_chunk_overlap, - separators=self.summary_separators - ) - - texts = character_splitter.split_text(page_contents) - docs = [Document(page_content=t) for t in texts] - - if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'): - return "No content found." - - # only use first 5 docs - if len(docs) > 5: - docs = docs[:5] - - chain = self.get_summary_chain() - try: - page_contents = chain.run(docs) - except Exception as e: - return f'Read this website failed, caused by: {str(e)}.' 
- else: - page_contents = page_result(page_contents, cursor, self.max_chunk_length) - - if self.continue_reading and len(page_contents) >= self.max_chunk_length: - page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ - f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \ - f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." - - return page_contents - - async def _arun(self, url: str) -> str: - raise NotImplementedError - - def get_summary_chain(self) -> RefineDocumentsChain: - initial_chain = LLMChain( - model_config=self.model_config, - prompt=refine_prompts.PROMPT, - parameters=self.model_parameters - ) - refine_chain = LLMChain( - model_config=self.model_config, - prompt=refine_prompts.REFINE_PROMPT, - parameters=self.model_parameters - ) - return RefineDocumentsChain( - initial_llm_chain=initial_chain, - refine_llm_chain=refine_chain, - document_variable_name="text", - initial_response_name="existing_answer", - callbacks=self.callbacks - ) - - -def page_result(text: str, cursor: int, max_length: int) -> str: - """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" - return text[cursor: cursor + max_length] - - -def get_url(url: str) -> str: - """Fetch URL and return the contents as a string.""" - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" - } - supported_content_types = extract_processor.SUPPORT_URL_CONTENT_TYPES + ["text/html"] - - head_response = requests.head(url, headers=headers, allow_redirects=True, timeout=(5, 10)) - - if head_response.status_code != 200: - return "URL returned status code {}.".format(head_response.status_code) - - # check content-type - main_content_type = head_response.headers.get('Content-Type').split(';')[0].strip() - if main_content_type not in supported_content_types: - return "Unsupported content-type [{}] of URL.".format(main_content_type) - - if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES: - return ExtractProcessor.load_from_url(url, return_text=True) - - response = requests.get(url, headers=headers, allow_redirects=True, timeout=(5, 30)) - a = extract_using_readabilipy(response.text) - - if not a['plain_text'] or not a['plain_text'].strip(): - return get_url_from_newspaper3k(url) - - res = FULL_TEMPLATE.format( - title=a['title'], - authors=a['byline'], - publish_date=a['date'], - top_image="", - text=a['plain_text'] if a['plain_text'] else "", - ) - - return res - - -def get_url_from_newspaper3k(url: str) -> str: - - a = Article(url) - a.download() - a.parse() - - res = FULL_TEMPLATE.format( - title=a.title, - authors=a.authors, - publish_date=a.publish_date, - top_image=a.top_image, - text=a.text, - ) - - return res - - -def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html: - f_html.write(html) - f_html.close() - html_path = f_html.name - - # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file - article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path('readabilipy'), 'javascript') - with chdir(jsdir): - subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) - - # Read output of call to Readability.parse() from JSON file and return as Python dictionary - with open(article_json_path, encoding="utf-8") as json_file: - input_json = 
json.loads(json_file.read()) - - # Deleting files after processing - os.unlink(article_json_path) - os.unlink(html_path) - - article_json = { - "title": None, - "byline": None, - "date": None, - "content": None, - "plain_content": None, - "plain_text": None - } - # Populate article fields from readability fields where present - if input_json: - if "title" in input_json and input_json["title"]: - article_json["title"] = input_json["title"] - if "byline" in input_json and input_json["byline"]: - article_json["byline"] = input_json["byline"] - if "date" in input_json and input_json["date"]: - article_json["date"] = input_json["date"] - if "content" in input_json and input_json["content"]: - article_json["content"] = input_json["content"] - article_json["plain_content"] = plain_content(article_json["content"], False, False) - article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) - if "textContent" in input_json and input_json["textContent"]: - article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"]) - - return article_json - - -def find_module_path(module_name): - for package_path in site.getsitepackages(): - potential_path = os.path.join(package_path, module_name) - if os.path.exists(potential_path): - return potential_path - - return None - -@contextmanager -def chdir(path): - """Change directory in context and return to original on exit""" - # From https://stackoverflow.com/a/37996581, couldn't find a built-in - original_path = os.getcwd() - os.chdir(path) - try: - yield - finally: - os.chdir(original_path) - - -def extract_text_blocks_as_plain_text(paragraph_html): - # Load article as DOM - soup = BeautifulSoup(paragraph_html, 'html.parser') - # Select all lists - list_elements = soup.find_all(['ul', 'ol']) - # Prefix text in all list items with "* " and make lists paragraphs - for list_element in list_elements: - plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')]))) - list_element.string = plain_items - list_element.name = "p" - # Select all text blocks - text_blocks = [s.parent for s in soup.find_all(string=True)] - text_blocks = [plain_text_leaf_node(block) for block in text_blocks] - # Drop empty paragraphs - text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) - return text_blocks - - -def plain_text_leaf_node(element): - # Extract all text, stripped of any child HTML elements and normalise it - plain_text = normalise_text(element.get_text()) - if plain_text != "" and element.name == "li": - plain_text = "* {}, ".format(plain_text) - if plain_text == "": - plain_text = None - if "data-node-index" in element.attrs: - plain = {"node_index": element["data-node-index"], "text": plain_text} - else: - plain = {"text": plain_text} - return plain - - -def plain_content(readability_content, content_digests, node_indexes): - # Load article as DOM - soup = BeautifulSoup(readability_content, 'html.parser') - # Make all elements plain - elements = plain_elements(soup.contents, content_digests, node_indexes) - if node_indexes: - # Add node index attributes to nodes - elements = [add_node_indexes(element) for element in elements] - # Replace article contents with plain elements - soup.contents = elements - return str(soup) - - -def plain_elements(elements, content_digests, node_indexes): - # Get plain content versions of all elements - elements = [plain_element(element, content_digests, 
node_indexes) - for element in elements] - if content_digests: - # Add content digest attribute to nodes - elements = [add_content_digest(element) for element in elements] - return elements - - -def plain_element(element, content_digests, node_indexes): - # For lists, we make each item plain text - if is_leaf(element): - # For leaf node elements, extract the text content, discarding any HTML tags - # 1. Get element contents as text - plain_text = element.get_text() - # 2. Normalise the extracted text string to a canonical representation - plain_text = normalise_text(plain_text) - # 3. Update element content to be plain text - element.string = plain_text - elif is_text(element): - if is_non_printing(element): - # The simplified HTML may have come from Readability.js so might - # have non-printing text (e.g. Comment or CData). In this case, we - # keep the structure, but ensure that the string is empty. - element = type(element)("") - else: - plain_text = element.string - plain_text = normalise_text(plain_text) - element = type(element)(plain_text) - else: - # If not a leaf node or leaf type call recursively on child nodes, replacing - element.contents = plain_elements(element.contents, content_digests, node_indexes) - return element - - -def add_node_indexes(element, node_index="0"): - # Can't add attributes to string types - if is_text(element): - return element - # Add index to current element - element["data-node-index"] = node_index - # Add index to child elements - for local_idx, child in enumerate( - [c for c in element.contents if not is_text(c)], start=1): - # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format( - stem=node_index, local=local_idx) - add_node_indexes(child, node_index=child_index) - return element - - -def normalise_text(text): - """Normalise unicode and whitespace.""" - # Normalise unicode first to try and standardise whitespace characters as much as possible before normalising them - text = strip_control_characters(text) - text = normalise_unicode(text) - text = normalise_whitespace(text) - return text - - -def strip_control_characters(text): - """Strip out unicode control characters which might break the parsing.""" - # Unicode control characters - # [Cc]: Other, Control [includes new lines] - # [Cf]: Other, Format - # [Cn]: Other, Not Assigned - # [Co]: Other, Private Use - # [Cs]: Other, Surrogate - control_chars = set(['Cc', 'Cf', 'Cn', 'Co', 'Cs']) - retained_chars = ['\t', '\n', '\r', '\f'] - - # Remove non-printing control characters - return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text]) - - -def normalise_unicode(text): - """Normalise unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form = "NFKC" - text = unicodedata.normalize(normal_form, text) - return text - - -def normalise_whitespace(text): - """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" - text = regex.sub(r"\s+", " ", text) - # Remove leading and trailing whitespace - text = text.strip() - return text - -def is_leaf(element): - return (element.name in ['p', 'li']) - - -def is_text(element): - return isinstance(element, NavigableString) - - -def is_non_printing(element): - return any(isinstance(element, _e) for _e in [Comment, CData]) - - -def add_content_digest(element): - if not is_text(element): - element["data-content-digest"] = content_digest(element) - return 
element - - -def content_digest(element): - if is_text(element): - # Hash - trimmed_string = element.string.strip() - if trimmed_string == "": - digest = "" - else: - digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest() - else: - contents = element.contents - num_contents = len(contents) - if num_contents == 0: - # No hash when no child elements exist - digest = "" - elif num_contents == 1: - # If single child, use digest of child - digest = content_digest(contents[0]) - else: - # Build content digest from the "non-empty" digests of child nodes - digest = hashlib.sha256() - child_digests = list( - filter(lambda x: x != "", [content_digest(content) for content in contents])) - for child in child_digests: - digest.update(child.encode('utf-8')) - digest = digest.hexdigest() - return digest diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index 13f4bc2c3d..eb839e9341 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -55,6 +55,21 @@ class ApiBasedToolProviderController(ToolProviderController): en_US='The api key', zh_Hans='api key的值' ) + ), + 'api_key_header_prefix': ToolProviderCredentials( + name='api_key_header_prefix', + required=False, + default='basic', + type=ToolProviderCredentials.CredentialsType.SELECT, + help=I18nObject( + en_US='The prefix of the api key header', + zh_Hans='api key header 的前缀' + ), + options=[ + ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')), + ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')), + ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom')) + ] ) } elif auth_type == ApiProviderAuthType.NONE: diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index f6914d3473..781eff13b4 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -62,6 +62,17 @@ class ApiTool(Tool): if 'api_key_value' not in credentials: raise ToolProviderCredentialValidationError('Missing api_key_value') + elif not isinstance(credentials['api_key_value'], str): + raise ToolProviderCredentialValidationError('api_key_value must be a string') + + if 'api_key_header_prefix' in credentials: + api_key_header_prefix = credentials['api_key_header_prefix'] + if api_key_header_prefix == 'basic': + credentials['api_key_value'] = f'Basic {credentials["api_key_value"]}' + elif api_key_header_prefix == 'bearer': + credentials['api_key_value'] = f'Bearer {credentials["api_key_value"]}' + elif api_key_header_prefix == 'custom': + pass headers[api_key_header] = credentials['api_key_value'] @@ -200,7 +211,7 @@ class ApiTool(Tool): # replace path parameters for name, value in path_params.items(): - url = url.replace(f'{{{name}}}', value) + url = url.replace(f'{{{name}}}', f'{value}') # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored if 'Content-Type' in headers: diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index 57b6e090c4..d9934acff9 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -171,7 +171,7 @@ class DatasetMultiRetrieverTool(BaseTool): if dataset.indexing_technique == "economy": # use keyword table query - documents = 
RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + documents = RetrievalService.retrieve(retrival_method='keyword_search', dataset_id=dataset.id, query=query, top_k=self.top_k diff --git a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py index d3ec0fba69..13331d981b 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_retriever_tool.py @@ -69,7 +69,7 @@ class DatasetRetrieverTool(BaseTool): retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model if dataset.indexing_technique == "economy": # use keyword table query - documents = RetrievalService.retrieve(retrival_method=retrieval_model['search_method'], + documents = RetrievalService.retrieve(retrival_method='keyword_search', dataset_id=dataset.id, query=query, top_k=self.top_k diff --git a/api/core/tools/tool/dataset_retriever_tool.py b/api/core/tools/tool/dataset_retriever_tool.py index 6906c21024..30128c4dca 100644 --- a/api/core/tools/tool/dataset_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever_tool.py @@ -4,7 +4,7 @@ from langchain.tools import BaseTool from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom -from core.features.dataset_retrieval import DatasetRetrievalFeature +from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter from core.tools.tool.tool import Tool @@ -15,12 +15,12 @@ class DatasetRetrieverTool(Tool): @staticmethod def get_dataset_tools(tenant_id: str, - dataset_ids: list[str], - retrieve_config: DatasetRetrieveConfigEntity, - return_resource: bool, - invoke_from: InvokeFrom, - hit_callback: DatasetIndexToolCallbackHandler - ) -> list['DatasetRetrieverTool']: + dataset_ids: list[str], + retrieve_config: DatasetRetrieveConfigEntity, + return_resource: bool, + invoke_from: InvokeFrom, + hit_callback: DatasetIndexToolCallbackHandler + ) -> list['DatasetRetrieverTool']: """ get dataset tool """ @@ -46,7 +46,7 @@ class DatasetRetrieverTool(Tool): ) # restore retrieve strategy retrieve_config.retrieve_strategy = original_retriever_mode - + # convert langchain tools to Tools tools = [] for langchain_tool in langchain_tools: @@ -60,7 +60,7 @@ class DatasetRetrieverTool(Tool): llm=langchain_tool.description), runtime=DatasetRetrieverTool.Runtime() ) - + tools.append(tool) return tools @@ -68,13 +68,13 @@ class DatasetRetrieverTool(Tool): def get_runtime_parameters(self) -> list[ToolParameter]: return [ ToolParameter(name='query', - label=I18nObject(en_US='', zh_Hans=''), - human_description=I18nObject(en_US='', zh_Hans=''), - type=ToolParameter.ToolParameterType.STRING, - form=ToolParameter.ToolParameterForm.LLM, - llm_description='Query for the dataset to be used to retrieve the dataset.', - required=True, - default=''), + label=I18nObject(en_US='', zh_Hans=''), + human_description=I18nObject(en_US='', zh_Hans=''), + type=ToolParameter.ToolParameterType.STRING, + form=ToolParameter.ToolParameterForm.LLM, + llm_description='Query for the dataset to be used to retrieve the dataset.', + required=True, + default=''), ] def _invoke(self, user_id: str, tool_parameters: 
dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]: @@ -84,7 +84,7 @@ class DatasetRetrieverTool(Tool): query = tool_parameters.get('query', None) if not query: return self.create_text_message(text='please input query') - + # invoke dataset retriever tool result = self.langchain_tool._run(query=query) @@ -94,4 +94,4 @@ class DatasetRetrieverTool(Tool): """ validate the credentials for dataset retriever tool """ - pass \ No newline at end of file + pass diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 91c18be3f5..889316c235 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,6 @@ +import re +import uuid from json import loads as json_loads from requests import get @@ -46,7 +48,7 @@ class ApiBasedToolSchemaParser: parameters = [] if 'parameters' in interface['operation']: for parameter in interface['operation']['parameters']: - parameters.append(ToolParameter( + tool_parameter = ToolParameter( name=parameter['name'], label=I18nObject( en_US=parameter['name'], @@ -61,7 +63,14 @@ class ApiBasedToolSchemaParser: form=ToolParameter.ToolParameterForm.LLM, llm_description=parameter.get('description'), default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None, - )) + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter) + if typ: + tool_parameter.type = typ + + parameters.append(tool_parameter) # create tool bundle # check if there is a request body if 'requestBody' in interface['operation']: @@ -80,13 +89,14 @@ class ApiBasedToolSchemaParser: root = root[ref] # overwrite the content interface['operation']['requestBody']['content'][content_type]['schema'] = root + # parse body parameters if 'schema' in interface['operation']['requestBody']['content'][content_type]: body_schema = interface['operation']['requestBody']['content'][content_type]['schema'] required = body_schema['required'] if 'required' in body_schema else [] properties = body_schema['properties'] if 'properties' in body_schema else {} for name, property in properties.items(): - parameters.append(ToolParameter( + tool = ToolParameter( name=name, label=I18nObject( en_US=name, @@ -101,7 +111,14 @@ class ApiBasedToolSchemaParser: form=ToolParameter.ToolParameterForm.LLM, llm_description=property['description'] if 'description' in property else '', default=property['default'] if 'default' in property else None, - )) + ) + + # check if there is a type + typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property) + if typ: + tool.type = typ + + parameters.append(tool) # check if parameters is duplicated parameters_count = {} @@ -119,7 +136,11 @@ class ApiBasedToolSchemaParser: path = interface['path'] if interface['path'].startswith('/'): path = interface['path'][1:] - path = path.replace('/', '_') + # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ + path = re.sub(r'[^a-zA-Z0-9_-]', '', path) + if not path: + path = str(uuid.uuid4()) + interface['operation']['operationId'] = f'{path}_{interface["method"]}' bundles.append(ApiBasedToolBundle( @@ -134,7 +155,23 @@ class ApiBasedToolSchemaParser: )) return bundles + + @staticmethod + def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: + parameter = parameter or {} + typ = None + if 'type' in parameter: + typ = parameter['type'] + elif 'schema' in parameter and 'type' in parameter['schema']: + typ = parameter['schema']['type'] + if typ 
== 'integer' or typ == 'number': + return ToolParameter.ToolParameterType.NUMBER + elif typ == 'boolean': + return ToolParameter.ToolParameterType.BOOLEAN + elif typ == 'string': + return ToolParameter.ToolParameterType.STRING + @staticmethod def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiBasedToolBundle]: """ diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index 9975978357..ba10b318dc 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -7,23 +7,14 @@ import subprocess import tempfile import unicodedata from contextlib import contextmanager -from typing import Any import requests from bs4 import BeautifulSoup, CData, Comment, NavigableString -from langchain.chains import RefineDocumentsChain -from langchain.chains.summarize import refine_prompts -from langchain.text_splitter import RecursiveCharacterTextSplitter -from langchain.tools.base import BaseTool from newspaper import Article -from pydantic import BaseModel, Field from regex import regex -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity from core.rag.extractor import extract_processor from core.rag.extractor.extract_processor import ExtractProcessor -from core.rag.models.document import Document FULL_TEMPLATE = """ TITLE: {title} @@ -36,106 +27,6 @@ TEXT: """ -class WebReaderToolInput(BaseModel): - url: str = Field(..., description="URL of the website to read") - summary: bool = Field( - default=False, - description="When the user's question requires extracting the summarizing content of the webpage, " - "set it to true." - ) - cursor: int = Field( - default=0, - description="Start reading from this character." - "Use when the first response was truncated" - "and you want to continue reading the page." - "The value cannot exceed 24000.", - ) - - -class WebReaderTool(BaseTool): - """Reader tool for getting website title and contents. Gives more control than SimpleReaderTool.""" - - name: str = "web_reader" - args_schema: type[BaseModel] = WebReaderToolInput - description: str = "use this to read a website. " \ - "If you can answer the question based on the information provided, " \ - "there is no need to use." - page_contents: str = None - url: str = None - max_chunk_length: int = 4000 - summary_chunk_tokens: int = 4000 - summary_chunk_overlap: int = 0 - summary_separators: list[str] = ["\n\n", "。", ".", " ", ""] - continue_reading: bool = True - model_config: ModelConfigEntity - model_parameters: dict[str, Any] - - def _run(self, url: str, summary: bool = False, cursor: int = 0) -> str: - try: - if not self.page_contents or self.url != url: - page_contents = get_url(url) - self.page_contents = page_contents - self.url = url - else: - page_contents = self.page_contents - except Exception as e: - return f'Read this website failed, caused by: {str(e)}.' - - if summary: - character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=self.summary_chunk_tokens, - chunk_overlap=self.summary_chunk_overlap, - separators=self.summary_separators - ) - - texts = character_splitter.split_text(page_contents) - docs = [Document(page_content=t) for t in texts] - - if len(docs) == 0 or docs[0].page_content.endswith('TEXT:'): - return "No content found." 
- - # only use first 5 docs - if len(docs) > 5: - docs = docs[:5] - - chain = self.get_summary_chain() - try: - page_contents = chain.run(docs) - except Exception as e: - return f'Read this website failed, caused by: {str(e)}.' - else: - page_contents = page_result(page_contents, cursor, self.max_chunk_length) - - if self.continue_reading and len(page_contents) >= self.max_chunk_length: - page_contents += f"\nPAGE WAS TRUNCATED. IF YOU FIND INFORMATION THAT CAN ANSWER QUESTION " \ - f"THEN DIRECT ANSWER AND STOP INVOKING web_reader TOOL, OTHERWISE USE " \ - f"CURSOR={cursor+len(page_contents)} TO CONTINUE READING." - - return page_contents - - async def _arun(self, url: str) -> str: - raise NotImplementedError - - def get_summary_chain(self) -> RefineDocumentsChain: - initial_chain = LLMChain( - model_config=self.model_config, - prompt=refine_prompts.PROMPT, - parameters=self.model_parameters - ) - refine_chain = LLMChain( - model_config=self.model_config, - prompt=refine_prompts.REFINE_PROMPT, - parameters=self.model_parameters - ) - return RefineDocumentsChain( - initial_llm_chain=initial_chain, - refine_llm_chain=refine_chain, - document_variable_name="text", - initial_response_name="existing_answer", - callbacks=self.callbacks - ) - - def page_result(text: str, cursor: int, max_length: int) -> str: """Page through `text` and return a substring of `max_length` characters starting from `cursor`.""" return text[cursor: cursor + max_length] diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 2851624ba1..2e21e56266 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,7 +1,7 @@ import re import uuid -from core.agent.agent_executor import PlanningStrategy +from core.entities.agent_entities import PlanningStrategy from core.external_data_tool.factory import ExternalDataToolFactory from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 568974b74f..6d5a0537d3 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -133,8 +133,9 @@ class HitTestingService: if embedding_length <= 1: return [{'x': 0, 'y': 0}] - concatenate_data = np.array(embeddings).reshape(embedding_length, -1) - # concatenate_data = np.concatenate(embeddings) + noise = np.random.normal(0, 1e-4, np.array(embeddings).shape) + concatenate_data = np.array(embeddings) + noise + concatenate_data = concatenate_data.reshape(embedding_length, -1) perplexity = embedding_length / 2 + 1 if perplexity >= embedding_length: diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 923e44dd85..778b4e51d3 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,3 +1,5 @@ + +from flask import current_app from flask_login import current_user from extensions.ext_database import db @@ -31,7 +33,15 @@ class WorkspaceService: can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): - tenant_info['custom_config'] = tenant.custom_config_dict + if can_replace_logo and TenantService.has_roles(tenant, + [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): + base_url = current_app.config.get('FILES_URL') + 
replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None + remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) + + tenant_info['custom_config'] = { + 'remove_webapp_brand': remove_webapp_brand, + 'replace_webapp_logo': replace_webapp_logo, + } return tenant_info diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py index 16e4affc91..37e109c847 100644 --- a/api/tasks/clean_dataset_task.py +++ b/api/tasks/clean_dataset_task.py @@ -40,7 +40,6 @@ def clean_dataset_task(dataset_id: str, tenant_id: str, indexing_technique: str, indexing_technique=indexing_technique, index_struct=index_struct, collection_binding_id=collection_binding_id, - doc_form=doc_form ) documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 84e2029705..a646158dbd 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -58,7 +58,8 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, - notion_access_token=data_source_binding.access_token + notion_access_token=data_source_binding.access_token, + tenant_id=document.tenant_id ) last_edited_time = loader.get_notion_last_edited_time() diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 1af21f147e..0d6c144929 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -7,18 +7,18 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLarguageModel +from core.model_runtime.model_providers.wenxin.llm.llm import ErnieBotLargeLanguageModel def test_predefined_models(): - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() model_schemas = model.predefined_models() assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) def test_validate_credentials_for_chat_model(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( @@ -39,7 +39,7 @@ def test_validate_credentials_for_chat_model(): def test_invoke_model_ernie_bot(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -67,7 +67,7 @@ def test_invoke_model_ernie_bot(): def test_invoke_model_ernie_bot_turbo(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot-turbo', @@ -95,7 +95,7 @@ def test_invoke_model_ernie_bot_turbo(): def test_invoke_model_ernie_8k(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot-8k', @@ 
-123,7 +123,7 @@ def test_invoke_model_ernie_8k(): def test_invoke_model_ernie_bot_4(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot-4', @@ -151,7 +151,7 @@ def test_invoke_model_ernie_bot_4(): def test_invoke_stream_model(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -182,7 +182,7 @@ def test_invoke_stream_model(): def test_invoke_model_with_system(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -212,7 +212,7 @@ def test_invoke_model_with_system(): def test_invoke_with_search(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.invoke( model='ernie-bot', @@ -250,7 +250,7 @@ def test_invoke_with_search(): def test_get_num_tokens(): sleep(3) - model = ErnieBotLarguageModel() + model = ErnieBotLargeLanguageModel() response = model.get_num_tokens( model='ernie-bot', diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e3a7bdbbe2..7cd09fd6ea 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3.1' services: # API service api: - image: langgenius/dify-api:0.5.6 + image: langgenius/dify-api:0.5.7 restart: always environment: # Startup mode, 'api' starts the API server. @@ -135,7 +135,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.5.6 + image: langgenius/dify-api:0.5.7 restart: always environment: # Startup mode, 'worker' starts the Celery worker for processing the queue. @@ -206,7 +206,7 @@ services: # Frontend web application. 
web: - image: langgenius/dify-web:0.5.6 + image: langgenius/dify-web:0.5.7 restart: always environment: EDITION: SELF_HOSTED diff --git a/web/app/(commonLayout)/apps/Apps.tsx b/web/app/(commonLayout)/apps/Apps.tsx index 6b95e96886..63e9bc457b 100644 --- a/web/app/(commonLayout)/apps/Apps.tsx +++ b/web/app/(commonLayout)/apps/Apps.tsx @@ -12,6 +12,7 @@ import { useAppContext } from '@/context/app-context' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { CheckModal } from '@/hooks/use-pay' import TabSliderNew from '@/app/components/base/tab-slider-new' +import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import { SearchLg } from '@/app/components/base/icons/src/vender/line/general' import { XCircle } from '@/app/components/base/icons/src/vender/solid/general' @@ -35,7 +36,9 @@ const getKey = ( const Apps = () => { const { t } = useTranslation() const { isCurrentWorkspaceManager } = useAppContext() - const [activeTab, setActiveTab] = useState('all') + const [activeTab, setActiveTab] = useTabSearchParams({ + defaultTab: 'all', + }) const [keywords, setKeywords] = useState('') const [searchKeywords, setSearchKeywords] = useState('') diff --git a/web/app/(commonLayout)/datasets/Container.tsx b/web/app/(commonLayout)/datasets/Container.tsx index cb2f3ed9cc..5c2f227222 100644 --- a/web/app/(commonLayout)/datasets/Container.tsx +++ b/web/app/(commonLayout)/datasets/Container.tsx @@ -1,7 +1,7 @@ 'use client' // Libraries -import { useRef, useState } from 'react' +import { useRef } from 'react' import { useTranslation } from 'react-i18next' import useSWR from 'swr' @@ -15,6 +15,9 @@ import TabSliderNew from '@/app/components/base/tab-slider-new' // Services import { fetchDatasetApiBaseUrl } from '@/service/datasets' +// Hooks +import { useTabSearchParams } from '@/hooks/use-tab-searchparams' + const Container = () => { const { t } = useTranslation() @@ -23,7 +26,9 @@ const Container = () => { { value: 'api', text: t('dataset.datasetsApi') }, ] - const [activeTab, setActiveTab] = useState('dataset') + const [activeTab, setActiveTab] = useTabSearchParams({ + defaultTab: 'dataset', + }) const containerRef = useRef(null) const { data } = useSWR(activeTab === 'dataset' ? 
null : '/datasets/api-base-info', fetchDatasetApiBaseUrl) diff --git a/web/app/components/app/annotation/header-opts/index.tsx b/web/app/components/app/annotation/header-opts/index.tsx index 90b1a9672e..aba3b6324c 100644 --- a/web/app/components/app/annotation/header-opts/index.tsx +++ b/web/app/components/app/annotation/header-opts/index.tsx @@ -42,6 +42,7 @@ const HeaderOptions: FC = ({ const { locale } = useContext(I18n) const { CSVDownloader, Type } = useCSVDownloader() const [list, setList] = useState([]) + const annotationUnavailable = list.length === 0 const listTransformer = (list: AnnotationItemBasic[]) => list.map( (item: AnnotationItemBasic) => { @@ -116,11 +117,11 @@ const HeaderOptions: FC = ({ ...list.map(item => [item.question, item.answer]), ]} > - + CSV - + JSONL diff --git a/web/app/components/app/annotation/header-opts/style.module.css b/web/app/components/app/annotation/header-opts/style.module.css index 29d43f449d..68234aed00 100644 --- a/web/app/components/app/annotation/header-opts/style.module.css +++ b/web/app/components/app/annotation/header-opts/style.module.css @@ -19,7 +19,7 @@ } .actionItem { - @apply h-9 py-2 px-3 mx-1 flex items-center space-x-2 hover:bg-gray-100 rounded-lg cursor-pointer; + @apply h-9 py-2 px-3 mx-1 flex items-center space-x-2 hover:bg-gray-100 rounded-lg cursor-pointer disabled:opacity-50; width: calc(100% - 0.5rem); } @@ -35,4 +35,4 @@ left: 4px; transform: translateX(-100%); box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03); -} \ No newline at end of file +} diff --git a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx index 29ecce5281..6be76210da 100644 --- a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx @@ -18,7 +18,7 @@ import { getNewVar } from '@/utils/var' import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight' import { Plus, Trash03 } from '@/app/components/base/icons/src/vender/line/general' -const MAX_QUESTION_NUM = 3 +const MAX_QUESTION_NUM = 5 export type IOpeningStatementProps = { value: string diff --git a/web/app/components/base/logo/logo-site.tsx b/web/app/components/base/logo/logo-site.tsx index 9d9bccfaf8..65569c8c99 100644 --- a/web/app/components/base/logo/logo-site.tsx +++ b/web/app/components/base/logo/logo-site.tsx @@ -1,15 +1,17 @@ import type { FC } from 'react' +import classNames from 'classnames' type LogoSiteProps = { className?: string } + const LogoSite: FC = ({ className, }) => { return ( ) diff --git a/web/app/components/billing/billing-page/index.tsx b/web/app/components/billing/billing-page/index.tsx index 494851ea5c..843d0995e5 100644 --- a/web/app/components/billing/billing-page/index.tsx +++ b/web/app/components/billing/billing-page/index.tsx @@ -1,7 +1,8 @@ 'use client' import type { FC } from 'react' -import React, { useEffect } from 'react' +import React from 'react' import { useTranslation } from 'react-i18next' +import useSWR from 'swr' import PlanComp from '../plan' import { ReceiptList } from '../../base/icons/src/vender/line/financeAndECommerce' import { LinkExternal01 } from '../../base/icons/src/vender/line/general' @@ -12,17 +13,11 @@ import { useProviderContext } from '@/context/provider-context' const Billing: FC = () => { const { t } = useTranslation() const { 
isCurrentWorkspaceManager } = useAppContext() - const [billingUrl, setBillingUrl] = React.useState('') const { enableBilling } = useProviderContext() - - useEffect(() => { - if (!enableBilling || !isCurrentWorkspaceManager) - return - (async () => { - const { url } = await fetchBillingUrl() - setBillingUrl(url) - })() - }, [isCurrentWorkspaceManager]) + const { data: billingUrl } = useSWR( + (!enableBilling || !isCurrentWorkspaceManager) ? null : ['/billing/invoices'], + () => fetchBillingUrl().then(data => data.url), + ) return ( @@ -39,4 +34,5 @@ const Billing: FC = () => { ) } + export default React.memo(Billing) diff --git a/web/app/components/custom/custom-web-app-brand/index.tsx b/web/app/components/custom/custom-web-app-brand/index.tsx index 4817cfddab..857706bf26 100644 --- a/web/app/components/custom/custom-web-app-brand/index.tsx +++ b/web/app/components/custom/custom-web-app-brand/index.tsx @@ -16,8 +16,6 @@ import { updateCurrentWorkspace, } from '@/service/common' import { useAppContext } from '@/context/app-context' -import { API_PREFIX } from '@/config' -import { getPurifyHref } from '@/utils' const ALLOW_FILE_EXTENSIONS = ['svg', 'png'] @@ -123,7 +121,7 @@ const CustomWebAppBrand = () => { POWERED BY { webappLogo - ? + ? : } diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx index 9bc994551b..f930cfe1c9 100644 --- a/web/app/components/develop/template/template.en.mdx +++ b/web/app/components/develop/template/template.en.mdx @@ -289,9 +289,9 @@ The text generation application offers non-session support and is ideal for tran ### Request Example - + ```bash {{ title: 'cURL' }} - curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \ + curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \ -H 'Authorization: Bearer {api_key}' \ -H 'Content-Type: application/json' \ --data-raw '{ diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index 6e2ff881d1..8153906d0a 100644 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -266,9 +266,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `result` (string) 固定返回 success - + ```bash {{ title: 'cURL' }} - curl -X POST 'https://cloud.dify.ai/v1/completion-messages/:task_id/stop' \ + curl -X POST '${props.appDetail.api_base_url}/completion-messages/:task_id/stop' \ -H 'Authorization: Bearer {api_key}' \ -H 'Content-Type: application/json' \ --data-raw '{ diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index 9e8dd69874..e102108154 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -344,9 +344,9 @@ Chat applications support session persistence, allowing previous chat history to ### Request Example - + ```bash {{ title: 'cURL' }} - curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \ + curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \ -H 'Authorization: Bearer {api_key}' \ -H 'Content-Type: application/json' \ --data-raw '{ @@ -1025,9 +1025,9 @@ Chat applications support session persistence, allowing previous chat history to - (string) url of icon - + ```bash {{ title: 'cURL' }} - curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \ + curl -X GET 
'${props.appDetail.api_base_url}/meta?user=abc-123' \ -H 'Authorization: Bearer {api_key}' ``` diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index 47f64466e7..7bc3cd5337 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -360,9 +360,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `result` (string) 固定返回 success - + ```bash {{ title: 'cURL' }} - curl -X POST 'https://cloud.dify.ai/v1/chat-messages/:task_id/stop' \ + curl -X POST '${props.appDetail.api_base_url}/chat-messages/:task_id/stop' \ -H 'Authorization: Bearer {api_key}' \ -H 'Content-Type: application/json' \ --data-raw '{ @@ -1022,9 +1022,9 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - (string) 图标URL - + ```bash {{ title: 'cURL' }} - curl -X GET 'https://cloud.dify.ai/v1/meta?user=abc-123' \ + curl -X GET '${props.appDetail.api_base_url}/meta?user=abc-123' \ -H 'Authorization: Bearer {api_key}' ``` diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index 9e28ea3ded..ac4229974e 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -1,18 +1,20 @@ 'use client' -import React, { useEffect } from 'react' +import React from 'react' import cn from 'classnames' import { useRouter } from 'next/navigation' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' +import useSWR from 'swr' import Toast from '../../base/toast' import s from './style.module.css' import ExploreContext from '@/context/explore-context' -import type { App, AppCategory } from '@/models/explore' +import type { App } from '@/models/explore' import Category from '@/app/components/explore/category' import AppCard from '@/app/components/explore/app-card' import { fetchAppDetail, fetchAppList } from '@/service/explore' import { createApp } from '@/service/apps' +import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import CreateAppModal from '@/app/components/explore/create-app-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import Loading from '@/app/components/base/loading' @@ -36,32 +38,43 @@ const Apps = ({ const { isCurrentWorkspaceManager } = useAppContext() const router = useRouter() const { hasEditPermission } = useContext(ExploreContext) - const [currCategory, setCurrCategory] = React.useState('') - const [allList, setAllList] = React.useState([]) - const [isLoaded, setIsLoaded] = React.useState(false) + const allCategoriesEn = t('explore.apps.allCategories') + const [currCategory, setCurrCategory] = useTabSearchParams({ + defaultTab: allCategoriesEn, + }) + const { + data: { categories, allList }, + } = useSWR( + ['/explore/apps'], + () => + fetchAppList().then(({ categories, recommended_apps }) => ({ + categories, + allList: recommended_apps.sort((a, b) => a.position - b.position), + })), + { + fallbackData: { + categories: [], + allList: [], + }, + }, + ) const currList = (() => { - if (currCategory === '') + if (currCategory === allCategoriesEn) return allList return allList.filter(item => item.category === currCategory) })() - const [categories, setCategories] = React.useState([]) - useEffect(() => { - (async () => { - const { categories, recommended_apps }: any = await fetchAppList() - const 
sortedRecommendedApps = [...recommended_apps] - sortedRecommendedApps.sort((a, b) => a.position - b.position) // position from small to big - setCategories(categories) - setAllList(sortedRecommendedApps) - setIsLoaded(true) - })() - }, []) - const [currApp, setCurrApp] = React.useState(null) const [isShowCreateModal, setIsShowCreateModal] = React.useState(false) - const onCreate: CreateAppModalProps['onConfirm'] = async ({ name, icon, icon_background, description }) => { - const { app_model_config: model_config } = await fetchAppDetail(currApp?.app.id as string) + const onCreate: CreateAppModalProps['onConfirm'] = async ({ + name, + icon, + icon_background, + }) => { + const { app_model_config: model_config } = await fetchAppDetail( + currApp?.app.id as string, + ) try { const app = await createApp({ @@ -78,17 +91,20 @@ const Apps = ({ message: t('app.newApp.appCreated'), }) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - router.push(`/app/${app.id}/${isCurrentWorkspaceManager ? 'configuration' : 'overview'}`) + router.push( + `/app/${app.id}/${isCurrentWorkspaceManager ? 'configuration' : 'overview' + }`, + ) } catch (e) { Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) } } - if (!isLoaded) { + if (!categories) { return ( - - + + ) } diff --git a/web/app/components/header/account-setting/index.tsx b/web/app/components/header/account-setting/index.tsx index a83542ef05..d0f5db243a 100644 --- a/web/app/components/header/account-setting/index.tsx +++ b/web/app/components/header/account-setting/index.tsx @@ -138,16 +138,12 @@ export default function AccountSetting({ ] const scrollRef = useRef(null) const [scrolled, setScrolled] = useState(false) - const scrollHandle = (e: Event) => { - if ((e.target as HTMLDivElement).scrollTop > 0) - setScrolled(true) - - else - setScrolled(false) - } useEffect(() => { const targetElement = scrollRef.current - + const scrollHandle = (e: Event) => { + const userScrolled = (e.target as HTMLDivElement).scrollTop > 0 + setScrolled(userScrolled) + } targetElement?.addEventListener('scroll', scrollHandle) return () => { targetElement?.removeEventListener('scroll', scrollHandle) diff --git a/web/app/components/header/index.tsx b/web/app/components/header/index.tsx index 3a818903e4..957b1a442f 100644 --- a/web/app/components/header/index.tsx +++ b/web/app/components/header/index.tsx @@ -54,7 +54,7 @@ const Header = () => { } {!isMobile && <> - + >} diff --git a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx index 1deef1b531..9da0ff7dcc 100644 --- a/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx +++ b/web/app/components/tools/edit-custom-collection-modal/config-credentials.tsx @@ -3,11 +3,13 @@ import type { FC } from 'react' import React from 'react' import { useTranslation } from 'react-i18next' import cn from 'classnames' +import Tooltip from '../../base/tooltip' +import { HelpCircle } from '../../base/icons/src/vender/line/general' import type { Credential } from '@/app/components/tools/types' import Drawer from '@/app/components/base/drawer-plus' import Button from '@/app/components/base/button' import Radio from '@/app/components/base/radio/ui' -import { AuthType } from '@/app/components/tools/types' +import { AuthHeaderPrefix, AuthType } from '@/app/components/tools/types' type Props = { credential: Credential @@ -18,9 +20,9 @@ const keyClassNames = 'py-2 leading-5 text-sm font-medium 
text-gray-900' type ItemProps = { text: string - value: AuthType + value: AuthType | AuthHeaderPrefix isChecked: boolean - onClick: (value: AuthType) => void + onClick: (value: AuthType | AuthHeaderPrefix) => void } const SelectItem: FC = ({ text, value, isChecked, onClick }) => { @@ -31,7 +33,6 @@ const SelectItem: FC = ({ text, value, isChecked, onClick }) => { > {text} - ) } @@ -43,6 +44,7 @@ const ConfigCredential: FC = ({ }) => { const { t } = useTranslation() const [tempCredential, setTempCredential] = React.useState(credential) + return ( = ({ text={t('tools.createTool.authMethod.types.none')} value={AuthType.none} isChecked={tempCredential.auth_type === AuthType.none} - onClick={value => setTempCredential({ ...tempCredential, auth_type: value })} + onClick={value => setTempCredential({ ...tempCredential, auth_type: value as AuthType })} /> setTempCredential({ ...tempCredential, auth_type: value })} + onClick={value => setTempCredential({ + ...tempCredential, + auth_type: value as AuthType, + api_key_header: tempCredential.api_key_header || 'Authorization', + api_key_value: tempCredential.api_key_value || '', + api_key_header_prefix: tempCredential.api_key_header_prefix || AuthHeaderPrefix.custom, + })} /> {tempCredential.auth_type === AuthType.apiKey && ( <> + {t('tools.createTool.authHeaderPrefix.title')} + + setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })} + /> + setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })} + /> + setTempCredential({ ...tempCredential, api_key_header_prefix: value as AuthHeaderPrefix })} + /> + - {t('tools.createTool.authMethod.key')} + + {t('tools.createTool.authMethod.key')} + + {t('tools.createTool.authMethod.keyTooltip')} + + } + > + + + setTempCredential({ ...tempCredential, api_key_header: e.target.value })} @@ -83,7 +124,6 @@ const ConfigCredential: FC = ({ placeholder={t('tools.createTool.authMethod.types.apiKeyPlaceholder')!} /> - {t('tools.createTool.authMethod.value')} = ({ const { t } = useTranslation() const isAdd = !payload const isEdit = !!payload + const [editFirst, setEditFirst] = useState(!isAdd) const [paramsSchemas, setParamsSchemas] = useState(payload?.tools || []) const [customCollection, setCustomCollection, getCustomCollection] = useGetState(isAdd @@ -44,6 +45,8 @@ const EditCustomCollectionModal: FC = ({ provider: '', credentials: { auth_type: AuthType.none, + api_key_header: 'Authorization', + api_key_header_prefix: AuthHeaderPrefix.basic, }, icon: { content: '🕵️', diff --git a/web/app/components/tools/index.tsx b/web/app/components/tools/index.tsx index d540cfb2b4..2f6eeb07e5 100644 --- a/web/app/components/tools/index.tsx +++ b/web/app/components/tools/index.tsx @@ -16,6 +16,7 @@ import EditCustomToolModal from './edit-custom-collection-modal' import NoCustomTool from './info/no-custom-tool' import NoSearchRes from './info/no-search-res' import NoCustomToolPlaceholder from './no-custom-tool-placeholder' +import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import TabSlider from '@/app/components/base/tab-slider' import { createCustomCollection, fetchCollectionList as doFetchCollectionList, fetchBuiltInToolList, fetchCustomToolList } from '@/service/tools' import type { AgentTool } from '@/types/app' @@ -68,7 +69,9 @@ const Tools: FC = ({ })() const [query, setQuery] = useState('') - const [collectionType, setCollectionType] = useState(collectionTypeOptions[0].value) + const [collectionType, setCollectionType] = useTabSearchParams({ + 
defaultTab: collectionTypeOptions[0].value, + }) const showCollectionList = (() => { let typeFilteredList: Collection[] = [] diff --git a/web/app/components/tools/tool-list/index.tsx b/web/app/components/tools/tool-list/index.tsx index 58fcf5613b..3bee3292e6 100644 --- a/web/app/components/tools/tool-list/index.tsx +++ b/web/app/components/tools/tool-list/index.tsx @@ -3,7 +3,7 @@ import type { FC } from 'react' import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import cn from 'classnames' -import { CollectionType, LOC } from '../types' +import { AuthHeaderPrefix, AuthType, CollectionType, LOC } from '../types' import type { Collection, CustomCollectionBackend, Tool } from '../types' import Loading from '../../base/loading' import { ArrowNarrowRight } from '../../base/icons/src/vender/line/arrows' @@ -53,6 +53,10 @@ const ToolList: FC = ({ (async () => { if (collection.type === CollectionType.custom) { const res = await fetchCustomCollection(collection.name) + if (res.credentials.auth_type === AuthType.apiKey && !res.credentials.api_key_header_prefix) { + if (res.credentials.api_key_value) + res.credentials.api_key_header_prefix = AuthHeaderPrefix.custom + } setCustomCollection({ ...res, provider: collection.name, diff --git a/web/app/components/tools/types.ts b/web/app/components/tools/types.ts index e06e011767..389276e81c 100644 --- a/web/app/components/tools/types.ts +++ b/web/app/components/tools/types.ts @@ -9,10 +9,17 @@ export enum AuthType { apiKey = 'api_key', } +export enum AuthHeaderPrefix { + basic = 'basic', + bearer = 'bearer', + custom = 'custom', +} + export type Credential = { 'auth_type': AuthType 'api_key_header'?: string 'api_key_value'?: string + 'api_key_header_prefix'?: AuthHeaderPrefix } export enum CollectionType { diff --git a/web/hooks/use-pay.tsx b/web/hooks/use-pay.tsx index 6ec4795940..cd43e8dd99 100644 --- a/web/hooks/use-pay.tsx +++ b/web/hooks/use-pay.tsx @@ -122,7 +122,7 @@ export const useCheckNotion = () => { const notionCode = searchParams.get('code') const notionError = searchParams.get('error') const { data } = useSWR( - canBinding + (canBinding && notionCode) ? `/oauth/data-source/binding/notion?code=${notionCode}` : null, fetchDataSourceNotionBinding, diff --git a/web/hooks/use-tab-searchparams.ts b/web/hooks/use-tab-searchparams.ts new file mode 100644 index 0000000000..f34c7be8c7 --- /dev/null +++ b/web/hooks/use-tab-searchparams.ts @@ -0,0 +1,34 @@ +import { usePathname, useRouter, useSearchParams } from 'next/navigation' + +type UseTabSearchParamsOptions = { + defaultTab: string + routingBehavior?: 'push' | 'replace' + searchParamName?: string +} + +/** + * Custom hook to manage tab state via URL search parameters in a Next.js application. + * This hook allows for syncing the active tab with the browser's URL, enabling bookmarking and sharing of URLs with a specific tab activated. + * + * @param {UseTabSearchParamsOptions} options Configuration options for the hook: + * - `defaultTab`: The tab to default to when no tab is specified in the URL. + * - `routingBehavior`: Optional. Determines how changes to the active tab update the browser's history ('push' or 'replace'). Default is 'push'. + * - `searchParamName`: Optional. The name of the search parameter that holds the tab state in the URL. Default is 'category'. + * @returns A tuple where the first element is the active tab and the second element is a function to set the active tab. 
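+ *
+ * @example
+ * // Illustrative sketch (tab values here are hypothetical): with the default
+ * // options, the active tab is read from and written to the `category` search param.
+ * const [activeTab, setActiveTab] = useTabSearchParams({ defaultTab: 'all' })
+ * setActiveTab('api') // calls router.push(`${pathname}?category=api`)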
diff --git a/web/i18n/en-US/tools.ts b/web/i18n/en-US/tools.ts
index 9746fab8bc..30e075210c 100644
--- a/web/i18n/en-US/tools.ts
+++ b/web/i18n/en-US/tools.ts
@@ -51,6 +51,7 @@ const translation = {
     authMethod: {
       title: 'Authorization method',
       type: 'Authorization type',
+      keyTooltip: 'HTTP header key. You can leave it as "Authorization" if you are not sure what it is, or set it to a custom value',
       types: {
         none: 'None',
         api_key: 'API Key',
@@ -60,6 +61,14 @@ const translation = {
       key: 'Key',
       value: 'Value',
     },
+    authHeaderPrefix: {
+      title: 'Auth Type',
+      types: {
+        basic: 'Basic',
+        bearer: 'Bearer',
+        custom: 'Custom',
+      },
+    },
     privacyPolicy: 'Privacy policy',
     privacyPolicyPlaceholder: 'Please enter privacy policy',
   },
diff --git a/web/i18n/pt-BR/tools.ts b/web/i18n/pt-BR/tools.ts
index 9e2da08a1a..3434bd15ee 100644
--- a/web/i18n/pt-BR/tools.ts
+++ b/web/i18n/pt-BR/tools.ts
@@ -58,6 +58,13 @@ const translation = {
       key: 'Chave',
       value: 'Valor',
     },
+    authHeaderPrefix: {
+      types: {
+        basic: 'Basic',
+        bearer: 'Bearer',
+        custom: 'Custom',
+      },
+    },
     privacyPolicy: 'Política de Privacidade',
     privacyPolicyPlaceholder: 'Digite a política de privacidade',
   },
diff --git a/web/i18n/uk-UA/tools.ts b/web/i18n/uk-UA/tools.ts
index 56b4371cfb..307149c386 100644
--- a/web/i18n/uk-UA/tools.ts
+++ b/web/i18n/uk-UA/tools.ts
@@ -58,6 +58,13 @@ const translation = {
       key: 'Ключ',
       value: 'Значення',
     },
+    authHeaderPrefix: {
+      types: {
+        basic: 'Basic',
+        bearer: 'Bearer',
+        custom: 'Custom',
+      },
+    },
     privacyPolicy: 'Політика конфіденційності',
     privacyPolicyPlaceholder: 'Введіть політику конфіденційності',
   },
diff --git a/web/i18n/zh-Hans/tools.ts b/web/i18n/zh-Hans/tools.ts
index ff3b5c0fb8..c709d62547 100644
--- a/web/i18n/zh-Hans/tools.ts
+++ b/web/i18n/zh-Hans/tools.ts
@@ -51,6 +51,7 @@ const translation = {
     authMethod: {
       title: '鉴权方法',
       type: '鉴权类型',
+      keyTooltip: 'HTTP 头部名称,如果你不知道是什么,可以将其保留为 Authorization 或设置为自定义值',
       types: {
         none: '无',
         api_key: 'API Key',
@@ -60,6 +61,14 @@ const translation = {
       key: '键',
       value: '值',
     },
+    authHeaderPrefix: {
+      title: '鉴权头部前缀',
+      types: {
+        basic: 'Basic',
+        bearer: 'Bearer',
+        custom: 'Custom',
+      },
+    },
     privacyPolicy: '隐私协议',
     privacyPolicyPlaceholder: '请输入隐私协议',
   },
diff --git a/web/models/explore.ts b/web/models/explore.ts
index a15fc17da3..739c325cfc 100644
--- a/web/models/explore.ts
+++ b/web/models/explore.ts
@@ -15,7 +15,7 @@ export type App = {
   app_id: string
   description: string
   copyright: string
-  privacy_policy: string
+  privacy_policy: string | null
   category: AppCategory
   position: number
   is_listed: boolean
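Because `privacy_policy` can now be `null`, consumers of the `App` model need a guard before rendering a link. A minimal sketch (the component name and markup are illustrative only; only the `App` shape comes from `web/models/explore.ts`):

```tsx
import type { App } from '@/models/explore'

// Hypothetical guard: render nothing when the app ships no privacy policy,
// which the API may now signal with an explicit null.
const PrivacyPolicyLink = ({ app }: { app: App }) => {
  if (!app.privacy_policy)
    return null

  return (
    <a href={app.privacy_policy} target='_blank' rel='noreferrer'>
      Privacy policy
    </a>
  )
}

export default PrivacyPolicyLink
```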
diff --git a/web/package.json b/web/package.json
index 08f0a75b1b..fe9c60778b 100644
--- a/web/package.json
+++ b/web/package.json
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.5.6",
+  "version": "0.5.7",
   "private": true,
   "scripts": {
     "dev": "next dev",
diff --git a/web/service/explore.ts b/web/service/explore.ts
index 60fb8b1128..bb608f7ee5 100644
--- a/web/service/explore.ts
+++ b/web/service/explore.ts
@@ -1,7 +1,11 @@
 import { del, get, patch, post } from './base'
+import type { App, AppCategory } from '@/models/explore'
 
 export const fetchAppList = () => {
-  return get('/explore/apps')
+  return get<{
+    categories: AppCategory[]
+    recommended_apps: App[]
+  }>('/explore/apps')
 }
 
 export const fetchAppDetail = (id: string): Promise<any> => {
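With the response type attached to `get<...>()`, call sites get the payload shape for free. A small usage sketch (`loadExploreData` is invented, and it assumes `get<T>` resolves to `T`, which the new signature suggests):

```ts
import { fetchAppList } from '@/service/explore'

// categories and recommended_apps are now fully typed, so no `as` casts
// are needed at the call site.
const loadExploreData = async () => {
  const { categories, recommended_apps } = await fetchAppList()
  console.log(`${categories.length} categories, ${recommended_apps.length} recommended apps`)
}
```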