mirror of https://github.com/langgenius/dify.git

commit f7a90f2660 ("merge main")

README.md (11 additions)

@@ -21,6 +21,17 @@
     <img alt="Docker Pulls" src="https://img.shields.io/docker/pulls/langgenius/dify-web"></a>
 </p>
 
+<p align="center">
+  <a href="https://discord.com/events/1082486657678311454/1211724120996188220" target="_blank">
+    Dify.AI Upcoming Meetup Event [👉 Click to Join the Event Here 👈]
+  </a>
+  <ul align="center" style="text-decoration: none; list-style: none;">
+    <li> US EST: 09:00 (9:00 AM)</li>
+    <li> CET: 15:00 (3:00 PM)</li>
+    <li> CST: 22:00 (10:00 PM)</li>
+  </ul>
+</p>
+
 <p align="center">
   <a href="https://dify.ai/blog/dify-ai-unveils-ai-agent-creating-gpts-and-assistants-with-various-llms" target="_blank">
     Dify.AI Unveils AI Agent: Creating GPTs and Assistants with Various LLMs
Binary file not shown.

api/commands.py (158 lines changed)

@@ -6,15 +6,15 @@ import click
 from flask import current_app
 from werkzeug.exceptions import NotFound
 
 from core.embedding.cached_embedding import CacheEmbedding
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
 from models.account import Tenant
-from models.dataset import Dataset
+from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment
 from models.dataset import Document as DatasetDocument
 from models.model import Account
 from models.provider import Provider, ProviderModel

@@ -124,14 +124,15 @@ def reset_encrypt_key_pair():
         'the asymmetric key pair of workspace {} has been reset.'.format(tenant.id), fg='green'))
 
 
-@click.command('create-qdrant-indexes', help='Create qdrant indexes.')
-def create_qdrant_indexes():
+@click.command('vdb-migrate', help='migrate vector db.')
+def vdb_migrate():
     """
-    Migrate other vector database datas to Qdrant.
+    Migrate vector database datas to target vector database .
     """
-    click.echo(click.style('Start create qdrant indexes.', fg='green'))
+    click.echo(click.style('Start migrate vector db.', fg='green'))
     create_count = 0
 
+    config = current_app.config
+    vector_type = config.get('VECTOR_STORE')
     page = 1
     while True:
         try:

@@ -140,54 +141,101 @@ def create_qdrant_indexes():
         except NotFound:
             break
 
-        model_manager = ModelManager()
 
         page += 1
         for dataset in datasets:
-            if dataset.index_struct_dict:
-                if dataset.index_struct_dict['type'] != 'qdrant':
-                    try:
-                        click.echo('Create dataset qdrant index: {}'.format(dataset.id))
-                        try:
-                            embedding_model = model_manager.get_model_instance(
-                                tenant_id=dataset.tenant_id,
-                                provider=dataset.embedding_model_provider,
-                                model_type=ModelType.TEXT_EMBEDDING,
-                                model=dataset.embedding_model
-                            )
-                        except Exception:
-                            continue
-                        embeddings = CacheEmbedding(embedding_model)
-
-                        from core.index.vector_index.qdrant_vector_index import QdrantConfig, QdrantVectorIndex
-
-                        index = QdrantVectorIndex(
-                            dataset=dataset,
-                            config=QdrantConfig(
-                                endpoint=current_app.config.get('QDRANT_URL'),
-                                api_key=current_app.config.get('QDRANT_API_KEY'),
-                                root_path=current_app.root_path
-                            ),
-                            embeddings=embeddings
-                        )
-                        if index:
-                            index.create_qdrant_dataset(dataset)
-                            index_struct = {
-                                "type": 'qdrant',
-                                "vector_store": {
-                                    "class_prefix": dataset.index_struct_dict['vector_store']['class_prefix']}
-                            }
-                            dataset.index_struct = json.dumps(index_struct)
-                            db.session.commit()
-                            create_count += 1
-                        else:
-                            click.echo('passed.')
-                    except Exception as e:
-                        click.echo(
-                            click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
-                                        fg='red'))
+            try:
+                click.echo('Create dataset vdb index: {}'.format(dataset.id))
+                if dataset.index_struct_dict:
+                    if dataset.index_struct_dict['type'] == vector_type:
+                        continue
+                if vector_type == "weaviate":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'weaviate',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                elif vector_type == "qdrant":
+                    if dataset.collection_binding_id:
+                        dataset_collection_binding = db.session.query(DatasetCollectionBinding). \
+                            filter(DatasetCollectionBinding.id == dataset.collection_binding_id). \
+                            one_or_none()
+                        if dataset_collection_binding:
+                            collection_name = dataset_collection_binding.collection_name
+                        else:
+                            raise ValueError('Dataset Collection Bindings is not exist!')
+                    else:
+                        dataset_id = dataset.id
+                        collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'qdrant',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+
+                elif vector_type == "milvus":
+                    dataset_id = dataset.id
+                    collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
+                    index_struct_dict = {
+                        "type": 'milvus',
+                        "vector_store": {"class_prefix": collection_name}
+                    }
+                    dataset.index_struct = json.dumps(index_struct_dict)
+                else:
+                    raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
+
+                vector = Vector(dataset)
+                click.echo(f"vdb_migrate {dataset.id}")
+
+                try:
+                    vector.delete()
+                except Exception as e:
+                    raise e
+
+                dataset_documents = db.session.query(DatasetDocument).filter(
+                    DatasetDocument.dataset_id == dataset.id,
+                    DatasetDocument.indexing_status == 'completed',
+                    DatasetDocument.enabled == True,
+                    DatasetDocument.archived == False,
+                ).all()
+
+                documents = []
+                for dataset_document in dataset_documents:
+                    segments = db.session.query(DocumentSegment).filter(
+                        DocumentSegment.document_id == dataset_document.id,
+                        DocumentSegment.status == 'completed',
+                        DocumentSegment.enabled == True
+                    ).all()
+
+                    for segment in segments:
+                        document = Document(
+                            page_content=segment.content,
+                            metadata={
+                                "doc_id": segment.index_node_id,
+                                "doc_hash": segment.index_node_hash,
+                                "document_id": segment.document_id,
+                                "dataset_id": segment.dataset_id,
+                            }
+                        )
+
+                        documents.append(document)
+
+                if documents:
+                    try:
+                        vector.create(documents)
+                    except Exception as e:
+                        raise e
+                click.echo(f"Dataset {dataset.id} create successfully.")
+                db.session.add(dataset)
+                db.session.commit()
+                create_count += 1
+            except Exception as e:
+                db.session.rollback()
+                click.echo(
+                    click.style('Create dataset index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+                continue
 
     click.echo(click.style('Congratulations! Create {} dataset indexes.'.format(create_count), fg='green'))

@@ -196,4 +244,4 @@ def register_commands(app):
    app.cli.add_command(reset_password)
    app.cli.add_command(reset_email)
    app.cli.add_command(reset_encrypt_key_pair)
-   app.cli.add_command(create_qdrant_indexes)
+   app.cli.add_command(vdb_migrate)
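
The rename is user-visible at the CLI: the command registered with the Flask app is now `vdb-migrate`. A minimal sketch of exercising it through Flask's test CLI runner, assuming the dify api project's app factory (`create_app` in api/app.py) and a configured VECTOR_STORE; names outside this diff are assumptions:

    # smoke-test the renamed command via Flask's CLI runner (illustrative only)
    from app import create_app  # assumed dify api entrypoint

    app = create_app()
    runner = app.test_cli_runner()
    result = runner.invoke(args=['vdb-migrate'])
    print(result.output)  # expect "Start migrate vector db." then per-dataset lines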

api/config.py

@@ -38,7 +38,9 @@ DEFAULTS = {
     'LOG_LEVEL': 'INFO',
     'HOSTED_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_OPENAI_TRIAL_ENABLED': 'False',
+    'HOSTED_OPENAI_TRIAL_MODELS': 'gpt-3.5-turbo,gpt-3.5-turbo-1106,gpt-3.5-turbo-instruct,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,text-davinci-003',
     'HOSTED_OPENAI_PAID_ENABLED': 'False',
+    'HOSTED_OPENAI_PAID_MODELS': 'gpt-4,gpt-4-turbo-preview,gpt-4-1106-preview,gpt-4-0125-preview,gpt-3.5-turbo,gpt-3.5-turbo-16k,gpt-3.5-turbo-16k-0613,gpt-3.5-turbo-1106,gpt-3.5-turbo-0613,gpt-3.5-turbo-0125,gpt-3.5-turbo-instruct,text-davinci-003',
     'HOSTED_AZURE_OPENAI_ENABLED': 'False',
     'HOSTED_AZURE_OPENAI_QUOTA_LIMIT': 200,
     'HOSTED_ANTHROPIC_QUOTA_LIMIT': 600000,

@@ -88,7 +90,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.5.6"
+        self.CURRENT_VERSION = "0.5.7"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')

@@ -261,8 +263,10 @@ class Config:
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
         self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION')
         self.HOSTED_OPENAI_TRIAL_ENABLED = get_bool_env('HOSTED_OPENAI_TRIAL_ENABLED')
+        self.HOSTED_OPENAI_TRIAL_MODELS = get_env('HOSTED_OPENAI_TRIAL_MODELS')
         self.HOSTED_OPENAI_QUOTA_LIMIT = int(get_env('HOSTED_OPENAI_QUOTA_LIMIT'))
         self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED')
+        self.HOSTED_OPENAI_PAID_MODELS = get_env('HOSTED_OPENAI_PAID_MODELS')
 
         self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED')
         self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY')
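
Note that DEFAULTS stores booleans as the strings 'False'/'True', so flag reads go through a boolean-parsing helper, while the new model lists stay plain comma-separated strings read with get_env. A minimal sketch of that helper pattern (function bodies are assumptions; dify's actual helpers live alongside this Config class):

    import os

    DEFAULTS = {'HOSTED_OPENAI_TRIAL_ENABLED': 'False'}

    def get_env(key):
        # environment wins; otherwise fall back to the DEFAULTS table
        return os.environ.get(key, DEFAULTS.get(key))

    def get_bool_env(key):
        # only the literal string 'true' (case-insensitive) enables a flag
        value = get_env(key)
        return value.lower() == 'true' if value else False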

api/controllers/console/app/app.py

@@ -124,19 +124,13 @@ class AppListApi(Resource):
             available_models_names = [f'{model.provider.provider}.{model.model}' for model in available_models]
             provider_model = f"{model_config_dict['model']['provider']}.{model_config_dict['model']['name']}"
             if provider_model not in available_models_names:
-                model_manager = ModelManager()
-                model_instance = model_manager.get_default_model_instance(
-                    tenant_id=current_user.current_tenant_id,
-                    model_type=ModelType.LLM
-                )
-
-                if not model_instance:
+                if not default_model_entity:
                     raise ProviderNotInitializeError(
                         "No Default System Reasoning Model available. Please configure "
                         "in the Settings -> Model Provider.")
                 else:
-                    model_config_dict["model"]["provider"] = model_instance.provider
-                    model_config_dict["model"]["name"] = model_instance.model
+                    model_config_dict["model"]["provider"] = default_model_entity.provider
+                    model_config_dict["model"]["name"] = default_model_entity.model
 
             model_configuration = AppModelConfigService.validate_configuration(
                 tenant_id=current_user.current_tenant_id,

api/controllers/console/datasets/data_source.py

@@ -178,7 +178,8 @@ class DataSourceNotionApi(Resource):
             notion_workspace_id=workspace_id,
             notion_obj_id=page_id,
             notion_page_type=page_type,
-            notion_access_token=data_source_binding.access_token
+            notion_access_token=data_source_binding.access_token,
+            tenant_id=current_user.current_tenant_id
         )
 
         text_docs = extractor.extract()

@@ -208,7 +209,8 @@ class DataSourceNotionApi(Resource):
                 notion_info={
                     "notion_workspace_id": workspace_id,
                     "notion_obj_id": page['page_id'],
-                    "notion_page_type": page['type']
+                    "notion_page_type": page['type'],
+                    "tenant_id": current_user.current_tenant_id
                 },
                 document_model=args['doc_form']
             )

api/controllers/console/datasets/datasets.py

@@ -298,7 +298,8 @@ class DatasetIndexingEstimateApi(Resource):
                 notion_info={
                     "notion_workspace_id": workspace_id,
                     "notion_obj_id": page['page_id'],
-                    "notion_page_type": page['type']
+                    "notion_page_type": page['type'],
+                    "tenant_id": current_user.current_tenant_id
                 },
                 document_model=args['doc_form']
             )

api/controllers/console/datasets/datasets_document.py

@@ -455,7 +455,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 notion_info={
                     "notion_workspace_id": data_source_info['notion_workspace_id'],
                     "notion_obj_id": data_source_info['notion_page_id'],
-                    "notion_page_type": data_source_info['type']
+                    "notion_page_type": data_source_info['type'],
+                    "tenant_id": current_user.current_tenant_id
                 },
                 document_model=document.doc_form
             )
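
The same one-line change lands in all three Notion call sites above: the extractor settings must now carry the workspace's tenant, presumably so downstream extraction can resolve per-tenant credentials. The resulting dict shape, with illustrative values:

    # shape of the new Notion extractor settings (values illustrative)
    notion_info = {
        "notion_workspace_id": "ws-123",
        "notion_obj_id": "page-456",
        "notion_page_type": "page",
        "tenant_id": "tenant-789",  # newly required key, from current_user.current_tenant_id
    }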

api/controllers/service_api/app/__init__.py (helper deleted)

@@ -1,27 +0,0 @@
-from extensions.ext_database import db
-from models.model import EndUser
-
-
-def create_or_update_end_user_for_user_id(app_model, user_id):
-    """
-    Create or update session terminal based on user ID.
-    """
-    end_user = db.session.query(EndUser) \
-        .filter(
-            EndUser.tenant_id == app_model.tenant_id,
-            EndUser.session_id == user_id,
-            EndUser.type == 'service_api'
-        ).first()
-
-    if end_user is None:
-        end_user = EndUser(
-            tenant_id=app_model.tenant_id,
-            app_id=app_model.id,
-            type='service_api',
-            is_anonymous=True,
-            session_id=user_id
-        )
-        db.session.add(end_user)
-        db.session.commit()
-
-    return end_user
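
This helper is not simply dropped: an extended version reappears in api/controllers/service_api/wraps.py further down in this commit, adding an app_id filter to the lookup and a 'DEFAULT-USER' fallback for anonymous callers.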

api/controllers/service_api/app/app.py

@@ -1,16 +1,16 @@
 import json
 
 from flask import current_app
-from flask_restful import fields, marshal_with
+from flask_restful import fields, marshal_with, Resource
 
 from controllers.service_api import api
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import validate_app_token
 from extensions.ext_database import db
 from models.model import App, AppModelConfig
 from models.tools import ApiToolProvider
 
 
-class AppParameterApi(AppApiResource):
+class AppParameterApi(Resource):
     """Resource for app variables."""
 
     variable_fields = {

@@ -42,8 +42,9 @@ class AppParameterApi(AppApiResource):
         'system_parameters': fields.Nested(system_parameters_fields)
     }
 
+    @validate_app_token
     @marshal_with(parameters_fields)
-    def get(self, app_model: App, end_user):
+    def get(self, app_model: App):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
 

@@ -64,8 +65,9 @@ class AppParameterApi(AppApiResource):
         }
     }
 
-class AppMetaApi(AppApiResource):
-    def get(self, app_model: App, end_user):
+class AppMetaApi(Resource):
+    @validate_app_token
+    def get(self, app_model: App):
         """Get app meta"""
         app_model_config: AppModelConfig = app_model.app_model_config
 
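
The pattern above replaces the old AppApiResource base class (whose method_decorators injected app_model plus a possibly-None end_user positionally) with plain flask_restful Resources and an explicit decorator. A minimal sketch of the bare form, using only names from this diff (the resource class itself is hypothetical):

    from flask_restful import Resource

    from controllers.service_api.wraps import validate_app_token
    from models.model import App


    class ExampleApi(Resource):
        # bare form: validates the Bearer app token and injects app_model only
        @validate_app_token
        def get(self, app_model: App):
            return {'app_id': app_model.id}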

api/controllers/service_api/app/audio.py

@@ -1,7 +1,7 @@
 import logging
 
 from flask import request
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError
 
 import services

@@ -17,10 +17,10 @@ from controllers.service_api.app.error import (
     ProviderQuotaExceededError,
     UnsupportedAudioTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppModelConfig
+from models.model import App, AppModelConfig, EndUser
 from services.audio_service import AudioService
 from services.errors.audio import (
     AudioTooLargeServiceError,

@@ -30,8 +30,9 @@ from services.errors.audio import (
 )
 
 
-class AudioApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class AudioApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
+    def post(self, app_model: App, end_user: EndUser):
         app_model_config: AppModelConfig = app_model.app_model_config
 
         if not app_model_config.speech_to_text_dict['enabled']:

@@ -73,11 +74,11 @@ class AudioApi(AppApiResource):
             raise InternalServerError()
 
 
-class TextApi(AppApiResource):
-    def post(self, app_model: App, end_user):
+class TextApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         parser = reqparse.RequestParser()
         parser.add_argument('text', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
         parser.add_argument('streaming', type=bool, required=False, nullable=False, location='json')
         args = parser.parse_args()
 

@@ -85,7 +86,7 @@ class TextApi(AppApiResource):
         response = AudioService.transcript_tts(
             tenant_id=app_model.tenant_id,
             text=args['text'],
-            end_user=args['user'],
+            end_user=end_user,
             voice=app_model.app_model_config.text_to_speech_dict.get('voice'),
             streaming=args['streaming']
         )
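
Worth noting: the fetch location follows each endpoint's content type. AudioApi (speech-to-text) receives multipart form data, so its decorator reads user from the form (WhereisUserArg.FORM); TextApi (text-to-speech) takes a JSON body, so it reads from JSON and marks it required. AudioService.transcript_tts now receives the resolved EndUser object rather than the raw user string.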

api/controllers/service_api/app/completion.py

@@ -4,12 +4,11 @@ from collections.abc import Generator
 from typing import Union
 
 from flask import Response, stream_with_context
-from flask_restful import reqparse
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import InternalServerError, NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     AppUnavailableError,
     CompletionRequestError,

@@ -19,17 +18,19 @@ from controllers.service_api.app.error import (
     ProviderNotInitializeError,
     ProviderQuotaExceededError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from core.application_queue_manager import ApplicationQueueManager
 from core.entities.application_entities import InvokeFrom
 from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
 from core.model_runtime.errors.invoke import InvokeError
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.completion_service import CompletionService
 
 
-class CompletionApi(AppApiResource):
-    def post(self, app_model, end_user):
+class CompletionApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 

@@ -38,16 +39,12 @@ class CompletionApi(AppApiResource):
         parser.add_argument('query', type=str, location='json', default='')
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
-        parser.add_argument('user', required=True, nullable=False, type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
 
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         args['auto_generate_name'] = False
 
         try:

@@ -82,29 +79,20 @@ class CompletionApi(AppApiResource):
             raise InternalServerError()
 
 
-class CompletionStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class CompletionStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'completion':
             raise AppUnavailableError()
 
-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
 
 
-class ChatApi(AppApiResource):
-    def post(self, app_model, end_user):
+class ChatApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 

@@ -114,7 +102,6 @@ class ChatApi(AppApiResource):
         parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
-        parser.add_argument('user', type=str, required=True, nullable=False, location='json')
        parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json')
 

@@ -122,9 +109,6 @@ class ChatApi(AppApiResource):
 
         streaming = args['response_mode'] == 'streaming'
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             response = CompletionService.completion(
                 app_model=app_model,

@@ -157,22 +141,12 @@ class ChatApi(AppApiResource):
             raise InternalServerError()
 
 
-class ChatStopApi(AppApiResource):
-    def post(self, app_model, end_user, task_id):
+class ChatStopApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
+    def post(self, app_model: App, end_user: EndUser, task_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
-        if end_user is None:
-            parser = reqparse.RequestParser()
-            parser.add_argument('user', required=True, nullable=False, type=str, location='json')
-            args = parser.parse_args()
-
-            user = args.get('user')
-            if user is not None:
-                end_user = create_or_update_end_user_for_user_id(app_model, user)
-            else:
-                raise ValueError("arg user muse be input.")
-
         ApplicationQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id)
 
         return {'result': 'success'}, 200
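
Existing API consumers should be unaffected: user still travels in the JSON body, it is simply consumed by the decorator before the handler runs. A hypothetical call against the service API (base URL and endpoint path assumed from dify's published API docs, not from this diff):

    import requests

    resp = requests.post(
        'https://api.dify.ai/v1/completion-messages',  # assumed endpoint
        headers={'Authorization': 'Bearer app-xxxxxxxxxxxx'},
        json={
            'inputs': {},
            'query': 'ping',
            'response_mode': 'blocking',
            'user': 'user-123',  # read by FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)
        },
    )
    print(resp.status_code, resp.json())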

api/controllers/service_api/app/conversation.py

@@ -1,52 +1,44 @@
 from flask import request
-from flask_restful import marshal_with, reqparse
+from flask_restful import Resource, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields
 from libs.helper import uuid_value
+from models.model import App, EndUser
 from services.conversation_service import ConversationService
 
 
-class ConversationApi(AppApiResource):
+class ConversationApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(conversation_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
         parser = reqparse.RequestParser()
         parser.add_argument('last_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit'])
         except services.errors.conversation.LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
 
-class ConversationDetailApi(AppApiResource):
+class ConversationDetailApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def delete(self, app_model, end_user, c_id):
+    def delete(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 
         conversation_id = str(c_id)
 
-        user = request.get_json().get('user')
-
-        if end_user is None and user is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user)
-
         try:
             ConversationService.delete(app_model, conversation_id, end_user)
         except services.errors.conversation.ConversationNotExistsError:

@@ -54,10 +46,11 @@ class ConversationDetailApi(AppApiResource):
         return {"result": "success"}, 204
 
 
-class ConversationRenameApi(AppApiResource):
+class ConversationRenameApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
     @marshal_with(simple_conversation_fields)
-    def post(self, app_model, end_user, c_id):
+    def post(self, app_model: App, end_user: EndUser, c_id):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 

@@ -65,13 +58,9 @@ class ConversationRenameApi(AppApiResource):
 
         parser = reqparse.RequestParser()
         parser.add_argument('name', type=str, required=False, location='json')
-        parser.add_argument('user', type=str, location='json')
         parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return ConversationService.rename(
                 app_model,

api/controllers/service_api/app/file.py

@@ -1,30 +1,27 @@
 from flask import request
-from flask_restful import marshal_with
+from flask_restful import Resource, marshal_with
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import (
     FileTooLargeError,
     NoFileUploadedError,
     TooManyFilesError,
     UnsupportedFileTypeError,
 )
-from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.file_fields import file_fields
+from models.model import App, EndUser
 from services.file_service import FileService
 
 
-class FileApi(AppApiResource):
+class FileApi(Resource):
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM))
     @marshal_with(file_fields)
-    def post(self, app_model, end_user):
+    def post(self, app_model: App, end_user: EndUser):
 
         file = request.files['file']
-        user_args = request.form.get('user')
-
-        if end_user is None and user_args is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
 
         # check file
         if 'file' not in request.files:

api/controllers/service_api/app/message.py

@@ -1,20 +1,18 @@
-from flask_restful import fields, marshal_with, reqparse
+from flask_restful import Resource, fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
 
 import services
 from controllers.service_api import api
-from controllers.service_api.app import create_or_update_end_user_for_user_id
 from controllers.service_api.app.error import NotChatAppError
-from controllers.service_api.wraps import AppApiResource
-from extensions.ext_database import db
+from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
 from fields.conversation_fields import message_file_fields
 from libs.helper import TimestampField, uuid_value
-from models.model import EndUser, Message
+from models.model import App, EndUser
 from services.message_service import MessageService
 
 
-class MessageListApi(AppApiResource):
+class MessageListApi(Resource):
     feedback_fields = {
         'rating': fields.String
     }

@@ -70,8 +68,9 @@ class MessageListApi(AppApiResource):
         'data': fields.List(fields.Nested(message_fields))
     }
 
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
     @marshal_with(message_infinite_scroll_pagination_fields)
-    def get(self, app_model, end_user):
+    def get(self, app_model: App, end_user: EndUser):
         if app_model.mode != 'chat':
             raise NotChatAppError()
 

@@ -79,12 +78,8 @@ class MessageListApi(AppApiResource):
         parser.add_argument('conversation_id', required=True, type=uuid_value, location='args')
         parser.add_argument('first_id', type=uuid_value, location='args')
         parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args')
-        parser.add_argument('user', type=str, location='args')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             return MessageService.pagination_by_first_id(app_model, end_user,
                                                          args['conversation_id'], args['first_id'], args['limit'])

@@ -94,18 +89,15 @@ class MessageListApi(AppApiResource):
             raise NotFound("First Message Not Exists.")
 
 
-class MessageFeedbackApi(AppApiResource):
-    def post(self, app_model, end_user, message_id):
+class MessageFeedbackApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
+    def post(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
 
         parser = reqparse.RequestParser()
         parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json')
-        parser.add_argument('user', type=str, location='json')
         args = parser.parse_args()
 
-        if end_user is None and args['user'] is not None:
-            end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
-
         try:
             MessageService.create_feedback(app_model, message_id, end_user, args['rating'])
         except services.errors.message.MessageNotExistsError:

@@ -114,29 +106,17 @@ class MessageFeedbackApi(AppApiResource):
         return {'result': 'success'}
 
 
-class MessageSuggestedApi(AppApiResource):
-    def get(self, app_model, end_user, message_id):
+class MessageSuggestedApi(Resource):
+    @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+    def get(self, app_model: App, end_user: EndUser, message_id):
         message_id = str(message_id)
         if app_model.mode != 'chat':
             raise NotChatAppError()
-        try:
-            message = db.session.query(Message).filter(
-                Message.id == message_id,
-                Message.app_id == app_model.id,
-            ).first()
-
-            if end_user is None and message.from_end_user_id is not None:
-                user = db.session.query(EndUser) \
-                    .filter(
-                        EndUser.tenant_id == app_model.tenant_id,
-                        EndUser.id == message.from_end_user_id,
-                        EndUser.type == 'service_api'
-                    ).first()
-            else:
-                user = end_user
+        try:
             questions = MessageService.get_suggested_questions_after_answer(
                 app_model=app_model,
-                user=user,
+                user=end_user,
                 message_id=message_id,
                 check_enabled=False
             )

api/controllers/service_api/wraps.py

@@ -1,22 +1,40 @@
+from collections.abc import Callable
 from datetime import datetime
+from enum import Enum
 from functools import wraps
+from typing import Optional
 
 from flask import current_app, request
 from flask_login import user_logged_in
 from flask_restful import Resource
+from pydantic import BaseModel
 from werkzeug.exceptions import NotFound, Unauthorized
 
 from extensions.ext_database import db
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin
-from models.model import ApiToken, App
+from models.model import ApiToken, App, EndUser
 from services.feature_service import FeatureService
 
 
-def validate_app_token(view=None):
-    def decorator(view):
-        @wraps(view)
-        def decorated(*args, **kwargs):
+class WhereisUserArg(Enum):
+    """
+    Enum for whereis_user_arg.
+    """
+    QUERY = 'query'
+    JSON = 'json'
+    FORM = 'form'
+
+
+class FetchUserArg(BaseModel):
+    fetch_from: WhereisUserArg
+    required: bool = False
+
+
+def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optional[FetchUserArg] = None):
+    def decorator(view_func):
+        @wraps(view_func)
+        def decorated_view(*args, **kwargs):
             api_token = validate_and_get_api_token('app')
 
             app_model = db.session.query(App).filter(App.id == api_token.app_id).first()

@@ -29,16 +47,35 @@ def validate_app_token(view=None):
             if not app_model.enable_api:
                 raise NotFound()
 
-            return view(app_model, None, *args, **kwargs)
-        return decorated
+            kwargs['app_model'] = app_model
 
-    if view:
-        return decorator(view)
-
-    # if view is None, it means that the decorator is used without parentheses
-    # use the decorator as a function for method_decorators
-    return decorator
+            if fetch_user_arg:
+                if fetch_user_arg.fetch_from == WhereisUserArg.QUERY:
+                    user_id = request.args.get('user')
+                elif fetch_user_arg.fetch_from == WhereisUserArg.JSON:
+                    user_id = request.get_json().get('user')
+                elif fetch_user_arg.fetch_from == WhereisUserArg.FORM:
+                    user_id = request.form.get('user')
+                else:
+                    # use default-user
+                    user_id = None
+
+                if not user_id and fetch_user_arg.required:
+                    raise ValueError("Arg user must be provided.")
+
+                if user_id:
+                    user_id = str(user_id)
+
+                kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id)
+
+            return view_func(*args, **kwargs)
+        return decorated_view
+
+    if view is None:
+        return decorator
+    else:
+        return decorator(view)
 
 
 def cloud_edition_billing_resource_check(resource: str,
                                          api_token_type: str,

@@ -128,8 +165,33 @@ def validate_and_get_api_token(scope=None):
     return api_token
 
 
-class AppApiResource(Resource):
-    method_decorators = [validate_app_token]
+def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] = None) -> EndUser:
+    """
+    Create or update session terminal based on user ID.
+    """
+    if not user_id:
+        user_id = 'DEFAULT-USER'
+
+    end_user = db.session.query(EndUser) \
+        .filter(
+            EndUser.tenant_id == app_model.tenant_id,
+            EndUser.app_id == app_model.id,
+            EndUser.session_id == user_id,
+            EndUser.type == 'service_api'
+        ).first()
+
+    if end_user is None:
+        end_user = EndUser(
+            tenant_id=app_model.tenant_id,
+            app_id=app_model.id,
+            type='service_api',
+            is_anonymous=True if user_id == 'DEFAULT-USER' else False,
+            session_id=user_id
+        )
+        db.session.add(end_user)
+        db.session.commit()
+
+    return end_user
 
 
 class DatasetApiResource(Resource):
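
Putting the new pieces together: a handler opts in to caller identity via fetch_user_arg, and the decorator resolves (or lazily creates) the EndUser before the view runs. A minimal sketch using only names introduced in this hunk (the resource class itself is hypothetical):

    from flask_restful import Resource

    from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
    from models.model import App, EndUser


    class WhoAmIApi(Resource):
        # parameterized form: reads `user` from the JSON body and injects both kwargs;
        # required=True turns a missing `user` field into an error
        @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True))
        def post(self, app_model: App, end_user: EndUser):
            return {'app_id': app_model.id, 'session_id': end_user.session_id}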

api/core/agent/agent/calc_token_mixin.py (deleted)

@@ -1,49 +0,0 @@
-from typing import cast
-
-from core.entities.application_entities import ModelConfigEntity
-from core.model_runtime.entities.message_entities import PromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey
-from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-
-class CalcTokenMixin:
-
-    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
-        """
-        Got the rest tokens available for the model after excluding messages tokens and completion max tokens
-
-        :param model_config:
-        :param messages:
-        :return:
-        """
-        model_type_instance = model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
-        max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                max_tokens = (model_config.parameters.get(parameter_rule.name)
-                              or model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        if model_context_tokens is None:
-            return 0
-
-        if max_tokens is None:
-            max_tokens = 0
-
-        prompt_tokens = model_type_instance.get_num_tokens(
-            model_config.model,
-            model_config.credentials,
-            messages
-        )
-
-        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
-
-        return rest_tokens
-
-
-class ExceededLLMTokensLimitError(Exception):
-    pass
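
The deleted mixin's core computation is simple budget arithmetic: remaining context = model context size, minus the configured completion ceiling, minus the prompt's own tokens. A worked example with assumed round numbers:

    model_context_tokens = 4096   # ModelPropertyKey.CONTEXT_SIZE
    max_tokens = 512              # the configured 'max_tokens' completion ceiling
    prompt_tokens = 3700          # counted from the assembled prompt messages

    rest_tokens = model_context_tokens - max_tokens - prompt_tokens
    print(rest_tokens)  # -116: negative, so the agents below would summarize history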

api/core/agent/agent/openai_function_call.py (deleted)

@@ -1,361 +0,0 @@
-from collections.abc import Sequence
-from typing import Any, Optional, Union
-
-from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
-from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
-from langchain.callbacks.base import BaseCallbackManager
-from langchain.callbacks.manager import Callbacks
-from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
-from langchain.memory.prompt import SUMMARY_PROMPT
-from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import (
-    AgentAction,
-    AgentFinish,
-    AIMessage,
-    BaseMessage,
-    HumanMessage,
-    SystemMessage,
-    get_buffer_string,
-)
-from langchain.tools import BaseTool
-from pydantic import root_validator
-
-from core.agent.agent.agent_llm_callback import AgentLLMCallback
-from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-from core.chain.llm_chain import LLMChain
-from core.entities.application_entities import ModelConfigEntity
-from core.entities.message_entities import lc_messages_to_prompt_messages
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
-from core.third_party.langchain.llms.fake import FakeLLM
-
-
-class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
-    moving_summary_buffer: str = ""
-    moving_summary_index: int = 0
-    summary_model_config: ModelConfigEntity = None
-    model_config: ModelConfigEntity
-    agent_llm_callback: Optional[AgentLLMCallback] = None
-
-    class Config:
-        """Configuration for this pydantic object."""
-
-        arbitrary_types_allowed = True
-
-    @root_validator
-    def validate_llm(cls, values: dict) -> dict:
-        return values
-
-    @classmethod
-    def from_llm_and_tools(
-            cls,
-            model_config: ModelConfigEntity,
-            tools: Sequence[BaseTool],
-            callback_manager: Optional[BaseCallbackManager] = None,
-            extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
-            system_message: Optional[SystemMessage] = SystemMessage(
-                content="You are a helpful AI assistant."
-            ),
-            agent_llm_callback: Optional[AgentLLMCallback] = None,
-            **kwargs: Any,
-    ) -> BaseSingleActionAgent:
-        prompt = cls.create_prompt(
-            extra_prompt_messages=extra_prompt_messages,
-            system_message=system_message,
-        )
-        return cls(
-            model_config=model_config,
-            llm=FakeLLM(response=''),
-            prompt=prompt,
-            tools=tools,
-            callback_manager=callback_manager,
-            agent_llm_callback=agent_llm_callback,
-            **kwargs,
-        )
-
-    def should_use_agent(self, query: str):
-        """
-        return should use agent
-
-        :param query:
-        :return:
-        """
-        original_max_tokens = 0
-        for parameter_rule in self.model_config.model_schema.parameter_rules:
-            if (parameter_rule.name == 'max_tokens'
-                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
-                original_max_tokens = (self.model_config.parameters.get(parameter_rule.name)
-                                       or self.model_config.parameters.get(parameter_rule.use_template)) or 0
-
-        self.model_config.parameters['max_tokens'] = 40
-
-        prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
-        messages = prompt.to_messages()
-
-        try:
-            prompt_messages = lc_messages_to_prompt_messages(messages)
-            model_instance = ModelInstance(
-                provider_model_bundle=self.model_config.provider_model_bundle,
-                model=self.model_config.model,
-            )
-
-            tools = []
-            for function in self.functions:
-                tool = PromptMessageTool(
-                    **function
-                )
-
-                tools.append(tool)
-
-            result = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                tools=tools,
-                stream=False,
-                model_parameters={
-                    'temperature': 0.2,
-                    'top_p': 0.3,
-                    'max_tokens': 1500
-                }
-            )
-        except Exception as e:
-            raise e
-
-        self.model_config.parameters['max_tokens'] = original_max_tokens
-
-        return True if result.message.tool_calls else False
-
-    def plan(
-            self,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            callbacks: Callbacks = None,
-            **kwargs: Any,
-    ) -> Union[AgentAction, AgentFinish]:
-        """Given input, decided what to do.
-
-        Args:
-            intermediate_steps: Steps the LLM has taken to date, along with observations
-            **kwargs: User inputs.
-
-        Returns:
-            Action specifying what tool to use.
-        """
-        agent_scratchpad = _format_intermediate_steps(intermediate_steps)
-        selected_inputs = {
-            k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
-        }
-        full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
-        prompt = self.prompt.format_prompt(**full_inputs)
-        messages = prompt.to_messages()
-
-        prompt_messages = lc_messages_to_prompt_messages(messages)
-
-        # summarize messages if rest_tokens < 0
-        try:
-            prompt_messages = self.summarize_messages_if_needed(prompt_messages, functions=self.functions)
-        except ExceededLLMTokensLimitError as e:
-            return AgentFinish(return_values={"output": str(e)}, log=str(e))
-
-        model_instance = ModelInstance(
-            provider_model_bundle=self.model_config.provider_model_bundle,
-            model=self.model_config.model,
-        )
-
-        tools = []
-        for function in self.functions:
-            tool = PromptMessageTool(
-                **function
-            )
-
-            tools.append(tool)
-
-        result = model_instance.invoke_llm(
-            prompt_messages=prompt_messages,
-            tools=tools,
-            stream=False,
-            callbacks=[self.agent_llm_callback] if self.agent_llm_callback else [],
-            model_parameters={
-                'temperature': 0.2,
-                'top_p': 0.3,
-                'max_tokens': 1500
-            }
-        )
-
-        ai_message = AIMessage(
-            content=result.message.content or "",
-            additional_kwargs={
-                'function_call': {
-                    'id': result.message.tool_calls[0].id,
-                    **result.message.tool_calls[0].function.dict()
-                } if result.message.tool_calls else None
-            }
-        )
-        agent_decision = _parse_ai_message(ai_message)
-
-        if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
-            tool_inputs = agent_decision.tool_input
-            if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
-                tool_inputs['query'] = kwargs['input']
-                agent_decision.tool_input = tool_inputs
-
-        return agent_decision
-
-    @classmethod
-    def get_system_message(cls):
-        return SystemMessage(content="You are a helpful AI assistant.\n"
-                                     "The current date or current time you know is wrong.\n"
-                                     "Respond directly if appropriate.")
-
-    def return_stopped_response(
-            self,
-            early_stopping_method: str,
-            intermediate_steps: list[tuple[AgentAction, str]],
-            **kwargs: Any,
-    ) -> AgentFinish:
-        try:
-            return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
-        except ValueError:
-            return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
-
-    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
-        # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
-        rest_tokens = self.get_message_rest_tokens(
-            self.model_config,
-            messages,
-            **kwargs
-        )
-
-        rest_tokens = rest_tokens - 20  # to deal with the inaccuracy of rest_tokens
-        if rest_tokens >= 0:
-            return messages
-
-        system_message = None
-        human_message = None
-        should_summary_messages = []
-        for message in messages:
-            if isinstance(message, SystemMessage):
-                system_message = message
-            elif isinstance(message, HumanMessage):
-                human_message = message
-            else:
-                should_summary_messages.append(message)
-
-        if len(should_summary_messages) > 2:
-            ai_message = should_summary_messages[-2]
-            function_message = should_summary_messages[-1]
-            should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
-            self.moving_summary_index = len(should_summary_messages)
-        else:
-            error_msg = "Exceeded LLM tokens limit, stopped."
-            raise ExceededLLMTokensLimitError(error_msg)
-
-        new_messages = [system_message, human_message]
-
-        if self.moving_summary_index == 0:
-            should_summary_messages.insert(0, human_message)
-
-        self.moving_summary_buffer = self.predict_new_summary(
-            messages=should_summary_messages,
-            existing_summary=self.moving_summary_buffer
-        )
-
-        new_messages.append(AIMessage(content=self.moving_summary_buffer))
-        new_messages.append(ai_message)
-        new_messages.append(function_message)
-
-        return new_messages
-
-    def predict_new_summary(
-            self, messages: list[BaseMessage], existing_summary: str
-    ) -> str:
-        new_lines = get_buffer_string(
-            messages,
-            human_prefix="Human",
-            ai_prefix="AI",
-        )
-
-        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
-        return chain.predict(summary=existing_summary, new_lines=new_lines)
-
-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
-        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
-
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
-        if model_config.provider == 'azure_openai':
-            model = model_config.model
-            model = model.replace("gpt-35", "gpt-3.5")
-        else:
-            model = model_config.credentials.get("base_model_name")
-
-        tiktoken_ = _import_tiktoken()
-        try:
-            encoding = tiktoken_.encoding_for_model(model)
-        except KeyError:
-            model = "cl100k_base"
-            encoding = tiktoken_.get_encoding(model)
-
-        if model.startswith("gpt-3.5-turbo"):
-            # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            tokens_per_message = 4
-            # if there's a name, the role is omitted
-            tokens_per_name = -1
-        elif model.startswith("gpt-4"):
-            tokens_per_message = 3
-            tokens_per_name = 1
-        else:
-            raise NotImplementedError(
-                f"get_num_tokens_from_messages() is not presently implemented "
-                f"for model {model}."
-                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
-                "information on how messages are converted to tokens."
-            )
-        num_tokens = 0
-        for m in messages:
-            message = _convert_message_to_dict(m)
-            num_tokens += tokens_per_message
-            for key, value in message.items():
-                if key == "function_call":
-                    for f_key, f_value in value.items():
-                        num_tokens += len(encoding.encode(f_key))
-                        num_tokens += len(encoding.encode(f_value))
-                else:
-                    num_tokens += len(encoding.encode(value))
-
-                if key == "name":
-                    num_tokens += tokens_per_name
-        # every reply is primed with <im_start>assistant
-        num_tokens += 3
-
-        if kwargs.get('functions'):
-            for function in kwargs.get('functions'):
-                num_tokens += len(encoding.encode('name'))
-                num_tokens += len(encoding.encode(function.get("name")))
-                num_tokens += len(encoding.encode('description'))
-                num_tokens += len(encoding.encode(function.get("description")))
-                parameters = function.get("parameters")
-                num_tokens += len(encoding.encode('parameters'))
-                if 'title' in parameters:
-                    num_tokens += len(encoding.encode('title'))
-                    num_tokens += len(encoding.encode(parameters.get("title")))
-                num_tokens += len(encoding.encode('type'))
-                num_tokens += len(encoding.encode(parameters.get("type")))
-                if 'properties' in parameters:
-                    num_tokens += len(encoding.encode('properties'))
-                    for key, value in parameters.get('properties').items():
-                        num_tokens += len(encoding.encode(key))
-                        for field_key, field_value in value.items():
-                            num_tokens += len(encoding.encode(field_key))
-                            if field_key == 'enum':
-                                for enum_field in field_value:
-                                    num_tokens += 3
-                                    num_tokens += len(encoding.encode(enum_field))
-                            else:
-                                num_tokens += len(encoding.encode(field_key))
-                                num_tokens += len(encoding.encode(str(field_value)))
-                if 'required' in parameters:
-                    num_tokens += len(encoding.encode('required'))
-                    for required_field in parameters['required']:
-                        num_tokens += 3
-                        num_tokens += len(encoding.encode(required_field))
-
-        return num_tokens
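
For orientation, the removed get_num_tokens_from_messages follows the OpenAI cookbook scheme: a fixed per-message overhead plus the encoded length of each field, plus a constant primer for the assistant's reply. A worked example under the gpt-3.5-turbo constants above, with illustrative encoded lengths:

    tokens_per_message = 4   # every message follows <im_start>{role/name}\n{content}<im_end>\n
    reply_primer = 3         # every reply is primed with <im_start>assistant

    # assume a 2-message chat whose role+content fields encode to 5 and 42 tokens
    encoded_lengths = [5, 42]
    num_tokens = sum(tokens_per_message + n for n in encoded_lengths) + reply_primer
    print(num_tokens)  # (4+5) + (4+42) + 3 = 58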
@@ -1,306 +0,0 @@
import re
from collections.abc import Sequence
from typing import Any, Optional, Union, cast

from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import (
    AgentAction,
    AgentFinish,
    AIMessage,
    BaseMessage,
    HumanMessage,
    OutputParserException,
    get_buffer_string,
)
from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
Valid "action" values: "Final Answer" or {tool_names}

Provide only ONE action per $JSON_BLOB, as shown:

```
{{{{
  "action": $TOOL_NAME,
  "action_input": $INPUT
}}}}
```

Follow this format:

Question: input question to answer
Thought: consider previous and subsequent steps
Action:
```
$JSON_BLOB
```
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
```
{{{{
  "action": "Final Answer",
  "action_input": "Final response to human"
}}}}
```"""


class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
    moving_summary_buffer: str = ""
    moving_summary_index: int = 0
    summary_model_config: ModelConfigEntity = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def should_use_agent(self, query: str):
        """
        return should use agent
        Using the ReACT mode to determine whether an agent is needed is costly,
        so it's better to just use an Agent for reasoning, which is cheaper.

        :param query:
        :return:
        """
        return True

    def plan(
        self,
        intermediate_steps: list[tuple[AgentAction, str]],
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """Given input, decided what to do.

        Args:
            intermediate_steps: Steps the LLM has taken to date,
                along with observations
            callbacks: Callbacks to run.
            **kwargs: User inputs.

        Returns:
            Action specifying what tool to use.
        """
        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
        prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])

        messages = []
        if prompts:
            messages = prompts[0].to_messages()

        prompt_messages = lc_messages_to_prompt_messages(messages)

        rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
        if rest_tokens < 0:
            full_inputs = self.summarize_messages(intermediate_steps, **kwargs)

        try:
            full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
        except Exception as e:
            raise e

        try:
            agent_decision = self.output_parser.parse(full_output)
            if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
                tool_inputs = agent_decision.tool_input
                if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
                    tool_inputs['query'] = kwargs['input']
                    agent_decision.tool_input = tool_inputs
            return agent_decision
        except OutputParserException:
            return AgentFinish({"output": "I'm sorry, the model's answer is invalid; "
                                          "I don't know how to respond to that."}, "")

    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
        if len(intermediate_steps) >= 2 and self.summary_model_config:
            should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
            should_summary_messages = [AIMessage(content=observation)
                                       for _, observation in should_summary_intermediate_steps]
            if self.moving_summary_index == 0:
                should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))

            self.moving_summary_index = len(intermediate_steps)
        else:
            error_msg = "Exceeded LLM tokens limit, stopped."
            raise ExceededLLMTokensLimitError(error_msg)

        if self.moving_summary_buffer and 'chat_history' in kwargs:
            kwargs["chat_history"].pop()

        self.moving_summary_buffer = self.predict_new_summary(
            messages=should_summary_messages,
            existing_summary=self.moving_summary_buffer
        )

        if 'chat_history' in kwargs:
            kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))

        return self.get_full_inputs([intermediate_steps[-1]], **kwargs)

    def predict_new_summary(
        self, messages: list[BaseMessage], existing_summary: str
    ) -> str:
        new_lines = get_buffer_string(
            messages,
            human_prefix="Human",
            ai_prefix="AI",
        )

        chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
        return chain.predict(summary=existing_summary, new_lines=new_lines)

    @classmethod
    def create_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
        memory_prompts: Optional[list[BasePromptTemplate]] = None,
    ) -> BasePromptTemplate:
        tool_strings = []
        for tool in tools:
            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
        formatted_tools = "\n".join(tool_strings)
        tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        _memory_prompts = memory_prompts or []
        messages = [
            SystemMessagePromptTemplate.from_template(template),
            *_memory_prompts,
            HumanMessagePromptTemplate.from_template(human_message_template),
        ]
        return ChatPromptTemplate(input_variables=input_variables, messages=messages)

    @classmethod
    def create_completion_prompt(
        cls,
        tools: Sequence[BaseTool],
        prefix: str = PREFIX,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
    ) -> PromptTemplate:
        """Create prompt in the style of the zero shot agent.

        Args:
            tools: List of tools the agent will have access to, used to format the
                prompt.
            prefix: String to put before the list of tools.
            input_variables: List of input variables the final prompt will expect.

        Returns:
            A PromptTemplate with the template assembled from the pieces here.
        """
        suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""

        tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
        tool_names = ", ".join([tool.name for tool in tools])
        format_instructions = format_instructions.format(tool_names=tool_names)
        template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return PromptTemplate(template=template, input_variables=input_variables)

    def _construct_scratchpad(
        self, intermediate_steps: list[tuple[AgentAction, str]]
    ) -> str:
        agent_scratchpad = ""
        for action, observation in intermediate_steps:
            agent_scratchpad += action.log
            agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"

        if not isinstance(agent_scratchpad, str):
            raise ValueError("agent_scratchpad should be of type string.")
        if agent_scratchpad:
            llm_chain = cast(LLMChain, self.llm_chain)
            if llm_chain.model_config.mode == "chat":
                return (
                    f"This was your previous work "
                    f"(but I haven't seen any of it! I only see what "
                    f"you return as final answer):\n{agent_scratchpad}"
                )
            else:
                return agent_scratchpad
        else:
            return agent_scratchpad

    @classmethod
    def from_llm_and_tools(
        cls,
        model_config: ModelConfigEntity,
        tools: Sequence[BaseTool],
        callback_manager: Optional[BaseCallbackManager] = None,
        output_parser: Optional[AgentOutputParser] = None,
        prefix: str = PREFIX,
        suffix: str = SUFFIX,
        human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
        format_instructions: str = FORMAT_INSTRUCTIONS,
        input_variables: Optional[list[str]] = None,
        memory_prompts: Optional[list[BasePromptTemplate]] = None,
        agent_llm_callback: Optional[AgentLLMCallback] = None,
        **kwargs: Any,
    ) -> Agent:
        """Construct an agent from an LLM and tools."""
        cls._validate_tools(tools)
        if model_config.mode == "chat":
            prompt = cls.create_prompt(
                tools,
                prefix=prefix,
                suffix=suffix,
                human_message_template=human_message_template,
                format_instructions=format_instructions,
                input_variables=input_variables,
                memory_prompts=memory_prompts,
            )
        else:
            prompt = cls.create_completion_prompt(
                tools,
                prefix=prefix,
                format_instructions=format_instructions,
                input_variables=input_variables,
            )
        llm_chain = LLMChain(
            model_config=model_config,
            prompt=prompt,
            callback_manager=callback_manager,
            agent_llm_callback=agent_llm_callback,
            parameters={
                'temperature': 0.2,
                'top_p': 0.3,
                'max_tokens': 1500
            }
        )
        tool_names = [tool.name for tool in tools]
        _output_parser = output_parser
        return cls(
            llm_chain=llm_chain,
            allowed_tools=tool_names,
            output_parser=_output_parser,
            **kwargs,
        )
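The FORMAT_INSTRUCTIONS above ask the model to emit its tool choice as a fenced JSON blob with `action` and `action_input` keys. As a rough illustration of what an output parser for that format has to do (a minimal sketch, not Dify's actual `StructuredChatOutputParser`), consider:

```python
import json
import re

# Minimal sketch of parsing the $JSON_BLOB emitted under the FORMAT_INSTRUCTIONS
# above; the pattern and function names are illustrative, not Dify's parser.
ACTION_PATTERN = re.compile(r"```(?:json)?\s*(\{.*?\})\s*```", re.DOTALL)

def parse_react_blob(text: str) -> tuple[str, object]:
    """Extract (action, action_input) from a ReAct-style model response."""
    match = ACTION_PATTERN.search(text)
    if not match:
        raise ValueError("no JSON action blob found in model output")
    blob = json.loads(match.group(1))
    return blob["action"], blob.get("action_input", "")

# Example: a conforming "Final Answer" response
sample = 'Thought: I know what to respond\nAction:\n```\n{"action": "Final Answer", "action_input": "42"}\n```'
assert parse_react_blob(sample) == ("Final Answer", "42")
```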
@@ -1,4 +1,3 @@
import json
import logging
from typing import cast
@@ -15,7 +14,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.moderation.base import ModerationException
from core.tools.entities.tool_entities import ToolRuntimeVariablePool
from extensions.ext_database import db
from models.model import App, Conversation, Message, MessageAgentThought, MessageChain
from models.model import App, Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)
@@ -173,11 +172,6 @@ class AssistantApplicationRunner(AppRunner):

        # convert db variables to tool variables
        tool_variables = self._convert_db_variables_to_tool_variables(tool_conversation_variables)

        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # init model instance
        model_instance = ModelInstance(
@@ -290,38 +284,6 @@ class AssistantApplicationRunner(AppRunner):
            'pool': db_variables.variables
        })

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
@@ -5,7 +5,7 @@ from core.app_runner.app_runner import AppRunner
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, DatasetEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.moderation.base import ModerationException
@@ -175,7 +175,7 @@ class GenerateTaskPipeline:
                'id': self._message.id,
                'message_id': self._message.id,
                'mode': self._conversation.mode,
                'answer': event.llm_result.message.content,
                'answer': self._task_state.llm_result.message.content,
                'metadata': {},
                'created_at': int(self._message.created_at.timestamp())
            }
@@ -3,12 +3,12 @@ import logging
from typing import Optional, cast

import numpy as np
from langchain.embeddings.base import Embeddings
from sqlalchemy.exc import IntegrityError

from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.entity.embedding import Embeddings
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
@@ -0,0 +1,8 @@
from enum import Enum


class PlanningStrategy(Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'
@@ -1,199 +0,0 @@
import logging
from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.application_queue_manager import ApplicationQueueManager
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.entities.application_entities import (
    AgentEntity,
    AppOrchestrationConfigEntity,
    InvokeFrom,
    ModelConfigEntity,
)
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers import model_provider_factory
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from models.dataset import Dataset
from models.model import Message

logger = logging.getLogger(__name__)


class AgentRunnerFeature:
    def __init__(self, tenant_id: str,
                 app_orchestration_config: AppOrchestrationConfigEntity,
                 model_config: ModelConfigEntity,
                 config: AgentEntity,
                 queue_manager: ApplicationQueueManager,
                 message: Message,
                 user_id: str,
                 agent_llm_callback: AgentLLMCallback,
                 callback: AgentLoopGatherCallbackHandler,
                 memory: Optional[TokenBufferMemory] = None,) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param app_orchestration_config: app orchestration config
        :param model_config: model config
        :param config: dataset config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param agent_llm_callback: agent llm callback
        :param callback: callback
        :param memory: memory
        """
        self.tenant_id = tenant_id
        self.app_orchestration_config = app_orchestration_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.agent_llm_callback = agent_llm_callback
        self.callback = callback
        self.memory = memory

    def run(self, query: str,
            invoke_from: InvokeFrom) -> Optional[str]:
        """
        Retrieve agent loop result.
        :param query: query
        :param invoke_from: invoke from
        :return:
        """
        provider = self.config.provider
        model = self.config.model
        tool_configs = self.config.tools

        # check whether the model supports tool calling
        provider_instance = model_provider_factory.get_provider_instance(provider=provider)
        model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        # get model schema
        model_schema = model_type_instance.get_model_schema(
            model=model,
            credentials=self.model_config.credentials
        )

        if not model_schema:
            return None

        planning_strategy = PlanningStrategy.REACT
        features = model_schema.features
        if features:
            if ModelFeature.TOOL_CALL in features \
                    or ModelFeature.MULTI_TOOL_CALL in features:
                planning_strategy = PlanningStrategy.FUNCTION_CALL

        tools = self.to_tools(
            tool_configs=tool_configs,
            invoke_from=invoke_from,
            callbacks=[self.callback, DifyStdOutCallbackHandler()],
        )

        if len(tools) == 0:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_config=self.model_config,
            tools=tools,
            memory=self.memory,
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate",
            agent_llm_callback=self.agent_llm_callback,
            callbacks=[self.callback, DifyStdOutCallbackHandler()]
        )

        agent_executor = AgentExecutor(agent_configuration)

        try:
            # check if we should use the agent at all
            should_use_agent = agent_executor.should_use_agent(query)
            if not should_use_agent:
                return None

            result = agent_executor.run(query)
            return result.output
        except Exception as ex:
            logger.exception("agent_executor run failed")
            return None

    def to_dataset_retriever_tool(self, tool_config: dict,
                                  invoke_from: InvokeFrom) \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset
        :param tool_config: tool config
        :param invoke_from: invoke from
        """
        show_retrieve_source = self.app_orchestration_config.show_retrieve_source

        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=self.queue_manager,
            app_id=self.message.app_id,
            message_id=self.message.id,
            user_id=self.user_id,
            invoke_from=invoke_from
        )

        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        # skip if the dataset is not available
        if not dataset:
            return None

        # skip if the dataset has no available documents
        if dataset.available_document_count == 0:
            return None

        # get retrieval model config
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }

        retrieval_model_config = dataset.retrieval_model \
            if dataset.retrieval_model else default_retrieval_model

        # get top k
        top_k = retrieval_model_config['top_k']

        # get score threshold
        score_threshold = None
        score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
        if score_threshold_enabled:
            score_threshold = retrieval_model_config.get("score_threshold")

        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            hit_callbacks=[hit_callback],
            return_resource=show_retrieve_source,
            retriever_from=invoke_from.to_source()
        )

        return tool
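The deleted `run()` above picks a planning strategy from the model schema: FUNCTION_CALL when the model advertises tool-call support, otherwise the ReAct fallback. A hedged, self-contained restatement of that dispatch (the enum values here mirror the snippet and are illustrative, not the library's definitions):

```python
from enum import Enum

# Illustrative re-statement of the strategy dispatch in run() above.
class ModelFeature(Enum):
    TOOL_CALL = 'tool-call'
    MULTI_TOOL_CALL = 'multi-tool-call'

class PlanningStrategy(Enum):
    REACT = 'react'
    FUNCTION_CALL = 'function_call'

def pick_strategy(features: list[ModelFeature] | None) -> PlanningStrategy:
    """Prefer native function calling when the model schema advertises it."""
    if features and (ModelFeature.TOOL_CALL in features
                     or ModelFeature.MULTI_TOOL_CALL in features):
        return PlanningStrategy.FUNCTION_CALL
    return PlanningStrategy.REACT

assert pick_strategy([ModelFeature.TOOL_CALL]) is PlanningStrategy.FUNCTION_CALL
assert pick_strategy(None) is PlanningStrategy.REACT
```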
@@ -59,7 +59,7 @@ class AnnotationReplyFeature:

            documents = vector.search_by_vector(
                query=query,
                k=1,
                top_k=1,
                score_threshold=score_threshold,
                filter={
                    'group_id': [dataset.id]
@@ -606,36 +606,42 @@ class BaseAssistantApplicationRunner(AppRunner):
        for message in messages:
            result.append(UserPromptMessage(content=message.query))
            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            for agent_thought in agent_thoughts:
                tools = agent_thought.tool
                if tools:
                    tools = tools.split(';')
                    tool_calls: list[AssistantPromptMessage.ToolCall] = []
                    tool_call_response: list[ToolPromptMessage] = []
                    tool_inputs = json.loads(agent_thought.tool_input)
                    for tool in tools:
                        # generate a uuid for tool call
                        tool_call_id = str(uuid.uuid4())
                        tool_calls.append(AssistantPromptMessage.ToolCall(
                            id=tool_call_id,
                            type='function',
                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                name=tool,
                                arguments=json.dumps(tool_inputs.get(tool, {})),
                            )
                        ))
                        tool_call_response.append(ToolPromptMessage(
                            content=agent_thought.observation,
                            name=tool,
                            tool_call_id=tool_call_id,
                        ))

                    result.extend([
                        AssistantPromptMessage(
                            content=agent_thought.thought,
                            tool_calls=tool_calls,
                        ),
                        *tool_call_response
                    ])
            if agent_thoughts:
                for agent_thought in agent_thoughts:
                    tools = agent_thought.tool
                    if tools:
                        tools = tools.split(';')
                        tool_calls: list[AssistantPromptMessage.ToolCall] = []
                        tool_call_response: list[ToolPromptMessage] = []
                        tool_inputs = json.loads(agent_thought.tool_input)
                        for tool in tools:
                            # generate a uuid for tool call
                            tool_call_id = str(uuid.uuid4())
                            tool_calls.append(AssistantPromptMessage.ToolCall(
                                id=tool_call_id,
                                type='function',
                                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                                    name=tool,
                                    arguments=json.dumps(tool_inputs.get(tool, {})),
                                )
                            ))
                            tool_call_response.append(ToolPromptMessage(
                                content=agent_thought.observation,
                                name=tool,
                                tool_call_id=tool_call_id,
                            ))

                        result.extend([
                            AssistantPromptMessage(
                                content=agent_thought.thought,
                                tool_calls=tool_calls,
                            ),
                            *tool_call_response
                        ])
                    if not tools:
                        result.append(AssistantPromptMessage(content=agent_thought.thought))
            else:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))

        return result
@@ -154,7 +154,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                thought='',
                action_str='',
                observation='',
                action=None
                action=None,
            )

            # publish agent thought if it's the first iteration
@@ -469,7 +469,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                    thought=message.content,
                    action_str='',
                    action=None,
                    observation=None
                    observation=None,
                )
                if message.tool_calls:
                    try:
@@ -484,7 +484,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
            elif isinstance(message, ToolPromptMessage):
                if current_scratchpad:
                    current_scratchpad.observation = message.content

        return agent_scratchpad

    def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"],
@@ -607,6 +607,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
                prompt_message.content = system_message
                overridden = True
                break

        # convert tool prompt messages to user prompt messages
        for idx, prompt_message in enumerate(prompt_messages):
            if isinstance(prompt_message, ToolPromptMessage):
                prompt_messages[idx] = UserPromptMessage(
                    content=prompt_message.content
                )

        if not overridden:
            prompt_messages.insert(0, SystemPromptMessage(
@@ -5,11 +5,11 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import Generation, LLMResult
from langchain.schema.language_model import BaseLanguageModel

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.third_party.langchain.llms.fake import FakeLLM


class LLMChain(LCLLMChain):
@@ -12,9 +12,9 @@ from pydantic import root_validator

from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import lc_messages_to_prompt_messages
from core.features.dataset_retrieval.agent.fake_llm import FakeLLM
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.third_party.langchain.llms.fake import FakeLLM


class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@@ -12,8 +12,8 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy
from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool

from core.chain.llm_chain import LLMChain
from core.entities.application_entities import ModelConfigEntity
from core.features.dataset_retrieval.agent.llm_chain import LLMChain

FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -1,4 +1,3 @@
import enum
import logging
from typing import Optional, Union
@@ -8,14 +7,13 @@ from langchain.callbacks.manager import Callbacks
from langchain.tools import BaseTool
from pydantic import BaseModel, Extra

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import ModelConfigEntity
from core.entities.message_entities import prompt_messages_to_lc_messages
from core.features.dataset_retrieval.agent.agent_llm_callback import AgentLLMCallback
from core.features.dataset_retrieval.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.features.dataset_retrieval.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.features.dataset_retrieval.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.helper import moderation
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.errors.invoke import InvokeError
@@ -23,13 +21,6 @@ from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import Datas
from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool


class PlanningStrategy(str, enum.Enum):
    ROUTER = 'router'
    REACT_ROUTER = 'react_router'
    REACT = 'react'
    FUNCTION_CALL = 'function_call'


class AgentConfiguration(BaseModel):
    strategy: PlanningStrategy
    model_config: ModelConfigEntity
@@ -62,28 +53,7 @@ class AgentExecutor:
        self.agent = self._init_agent()

    def _init_agent(self) -> Union[BaseSingleActionAgent, BaseMultiActionAgent]:
        if self.configuration.strategy == PlanningStrategy.REACT:
            agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                output_parser=StructuredChatOutputParser(),
                summary_model_config=self.configuration.summary_model_config
                if self.configuration.summary_model_config else None,
                agent_llm_callback=self.configuration.agent_llm_callback,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
            agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
                model_config=self.configuration.model_config,
                tools=self.configuration.tools,
                extra_prompt_messages=prompt_messages_to_lc_messages(self.configuration.memory.get_history_prompt_messages())
                if self.configuration.memory else None,  # used to read chat history from memory
                summary_model_config=self.configuration.summary_model_config
                if self.configuration.summary_model_config else None,
                agent_llm_callback=self.configuration.agent_llm_callback,
                verbose=True
            )
        elif self.configuration.strategy == PlanningStrategy.ROUTER:
        if self.configuration.strategy == PlanningStrategy.ROUTER:
            self.configuration.tools = [t for t in self.configuration.tools
                                        if isinstance(t, DatasetRetrieverTool)
                                        or isinstance(t, DatasetMultiRetrieverTool)]
@@ -2,9 +2,10 @@ from typing import Optional, cast

from langchain.tools import BaseTool

from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.agent_entities import PlanningStrategy
from core.entities.application_entities import DatasetEntity, DatasetRetrieveConfigEntity, InvokeFrom, ModelConfigEntity
from core.features.dataset_retrieval.agent_based_dataset_executor import AgentConfiguration, AgentExecutor
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -104,37 +104,17 @@ class HostingConfiguration:

        if app_config.get("HOSTED_OPENAI_TRIAL_ENABLED"):
            hosted_quota_limit = int(app_config.get("HOSTED_OPENAI_QUOTA_LIMIT", "200"))
            trial_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_TRIAL_MODELS")
            trial_quota = TrialHostingQuota(
                quota_limit=hosted_quota_limit,
                restrict_models=[
                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
                ]
                restrict_models=trial_models
            )
            quotas.append(trial_quota)

        if app_config.get("HOSTED_OPENAI_PAID_ENABLED"):
            paid_models = self.parse_restrict_models_from_env(app_config, "HOSTED_OPENAI_PAID_MODELS")
            paid_quota = PaidHostingQuota(
                restrict_models=[
                    RestrictModel(model="gpt-4", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-16k-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-1106", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0613", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-0125", model_type=ModelType.LLM),
                    RestrictModel(model="gpt-3.5-turbo-instruct", model_type=ModelType.LLM),
                    RestrictModel(model="text-davinci-003", model_type=ModelType.LLM),
                ]
                restrict_models=paid_models
            )
            quotas.append(paid_quota)
@@ -258,3 +238,11 @@ class HostingConfiguration:
        return HostedModerationConfig(
            enabled=False
        )

    @staticmethod
    def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[RestrictModel]:
        models_str = app_config.get(env_var)
        models_list = models_str.split(",") if models_str else []
        return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if
                model_name.strip()]
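The new `parse_restrict_models_from_env` turns a comma-separated environment value into RestrictModel entries, skipping blanks. A hedged, standalone equivalent with a plain dict standing in for the Flask Config object (an assumption for illustration):

```python
# Standalone equivalent of parse_restrict_models_from_env above; a plain dict
# stands in for the Flask Config object, and model names replace RestrictModel.
def parse_restrict_models(app_config: dict, env_var: str) -> list[str]:
    models_str = app_config.get(env_var)
    models_list = models_str.split(",") if models_str else []
    return [name.strip() for name in models_list if name.strip()]

config = {"HOSTED_OPENAI_TRIAL_MODELS": "gpt-3.5-turbo, gpt-4 ,,text-davinci-003"}
assert parse_restrict_models(config, "HOSTED_OPENAI_TRIAL_MODELS") == [
    "gpt-3.5-turbo", "gpt-4", "text-davinci-003"]
assert parse_restrict_models(config, "HOSTED_OPENAI_PAID_MODELS") == []
```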
@@ -365,8 +365,9 @@ class IndexingRunner:
                notion_info={
                    "notion_workspace_id": data_source_info['notion_workspace_id'],
                    "notion_obj_id": data_source_info['notion_page_id'],
                    "notion_page_type": data_source_info['notion_page_type'],
                    "document": dataset_document
                    "notion_page_type": data_source_info['type'],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id
                },
                document_model=dataset_document.doc_form
            )
@@ -664,6 +665,7 @@ class IndexingRunner:
            )
            # load index
            index_processor.load(dataset, chunk_documents)
            db.session.add(dataset)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
@@ -20,7 +20,7 @@

![](docs/zh_Hans/images/index/image-20231210143654461.png)

Displays the list of all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For details on the rule design, see: [Schema](./schema.md).
Displays the list of all supported providers. Besides the provider name and icon, it also returns the supported model types, the predefined model list, the configuration method, and the form rules for configuring credentials. For details on the rule design, see: [Schema](./docs/zh_Hans/schema.md).

- Display of the selectable model list
@@ -86,4 +86,4 @@ Model Runtime has three layers:
![](docs/zh_Hans/images/index/image-20231210144229650.png)

### [Concrete interface implementations 👈🏻](./docs/zh_Hans/interfaces.md)
Here you can find the concrete implementation of each interface, along with the exact meaning of its parameters and return values.
Here you can find the concrete implementation of each interface, along with the exact meaning of its parameters and return values.
@@ -81,5 +81,18 @@ PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = {
        'min': 1,
        'max': 2048,
        'precision': 0,
    },
    DefaultParameterName.RESPONSE_FORMAT: {
        'label': {
            'en_US': 'Response Format',
            'zh_Hans': '回复格式',
        },
        'type': 'string',
        'help': {
            'en_US': 'Set a response format to ensure the output from the LLM is, as far as possible, a valid code block, such as JSON or XML.',
            'zh_Hans': '设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等',
        },
        'required': False,
        'options': ['JSON', 'XML'],
    }
}
@@ -91,6 +91,7 @@ class DefaultParameterName(Enum):
    PRESENCE_PENALTY = "presence_penalty"
    FREQUENCY_PENALTY = "frequency_penalty"
    MAX_TOKENS = "max_tokens"
    RESPONSE_FORMAT = "response_format"

    @classmethod
    def value_of(cls, value: Any) -> 'DefaultParameterName':
@@ -262,23 +262,23 @@ class AIModel(ABC):
            try:
                default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template)
                default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name)
                if not parameter_rule.max:
                if not parameter_rule.max and 'max' in default_parameter_rule:
                    parameter_rule.max = default_parameter_rule['max']
                if not parameter_rule.min:
                if not parameter_rule.min and 'min' in default_parameter_rule:
                    parameter_rule.min = default_parameter_rule['min']
                if not parameter_rule.precision:
                if not parameter_rule.default and 'default' in default_parameter_rule:
                    parameter_rule.default = default_parameter_rule['default']
                if not parameter_rule.precision:
                if not parameter_rule.precision and 'precision' in default_parameter_rule:
                    parameter_rule.precision = default_parameter_rule['precision']
                if not parameter_rule.required:
                if not parameter_rule.required and 'required' in default_parameter_rule:
                    parameter_rule.required = default_parameter_rule['required']
                if not parameter_rule.help:
                if not parameter_rule.help and 'help' in default_parameter_rule:
                    parameter_rule.help = I18nObject(
                        en_US=default_parameter_rule['help']['en_US'],
                    )
                if not parameter_rule.help.en_US:
                if not parameter_rule.help.en_US and ('help' in default_parameter_rule and 'en_US' in default_parameter_rule['help']):
                    parameter_rule.help.en_US = default_parameter_rule['help']['en_US']
                if not parameter_rule.help.zh_Hans:
                if not parameter_rule.help.zh_Hans and ('help' in default_parameter_rule and 'zh_Hans' in default_parameter_rule['help']):
                    parameter_rule.help.zh_Hans = default_parameter_rule['help'].get('zh_Hans', default_parameter_rule['help']['en_US'])
            except ValueError:
                pass
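The hunk above fills unset fields of a provider's parameter rule from the default template, now guarding each lookup with a key check. A hedged dict-level sketch of the same fill-from-defaults pattern (note the falsy check mirrors the original's `if not …`, which treats 0 as unset):

```python
# Hedged sketch of the fill-from-defaults pattern used above: copy a key from
# the template only when the rule has no truthy value AND the template defines it.
def merge_rule(rule: dict, defaults: dict,
               keys: tuple[str, ...] = ("max", "min", "default",
                                        "precision", "required", "help")) -> dict:
    merged = dict(rule)
    for key in keys:
        # Like the original's `if not parameter_rule.x`, a value of 0 or ""
        # counts as unset and will be overwritten by the template.
        if not merged.get(key) and key in defaults:
            merged[key] = defaults[key]
    return merged

rule = {"name": "max_tokens", "max": None}
defaults = {"min": 1, "max": 2048, "precision": 0}
assert merge_rule(rule, defaults) == {"name": "max_tokens", "max": 2048,
                                      "min": 1, "precision": 0}
```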
@@ -9,7 +9,13 @@ from typing import Optional, Union
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import (
    ModelPropertyKey,
    ModelType,
@@ -74,7 +80,20 @@ class LargeLanguageModel(AIModel):
        )

        try:
            result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
            if "response_format" in model_parameters:
                result = self._code_block_mode_wrapper(
                    model=model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    callbacks=callbacks
                )
            else:
                result = self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        except Exception as e:
            self._trigger_invoke_error_callbacks(
                model=model,
@@ -120,6 +139,239 @@ class LargeLanguageModel(AIModel):

        return result

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper, ensure the response is a code block with output markdown quote

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :param callbacks: callbacks
        :return: full response or stream response chunk generator result
        """

        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can find in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

        code_block = model_parameters.get("response_format", "")
        if not code_block:
            return self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user
            )

        model_parameters.pop("response_format")
        stop = stop or []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", prompt_messages[0].content)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
            ))

        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
            # add ```JSON\n to the last message
            prompt_messages[-1].content += f"\n```{code_block}\n"
        else:
            # append a user message
            prompt_messages.append(UserPromptMessage(
                content=f"```{code_block}\n"
            ))

        response = self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

        if isinstance(response, Generator):
            first_chunk = next(response)

            def new_generator():
                yield first_chunk
                yield from response

            if first_chunk.delta.message.content and first_chunk.delta.message.content.startswith("`"):
                return self._code_block_mode_stream_processor_with_backtick(
                    model=model,
                    prompt_messages=prompt_messages,
                    input_generator=new_generator()
                )
            else:
                return self._code_block_mode_stream_processor(
                    model=model,
                    prompt_messages=prompt_messages,
                    input_generator=new_generator()
                )

        return response

    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
                                          input_generator: Generator[LLMResultChunk, None, None]
                                          ) -> Generator[LLMResultChunk, None, None]:
        """
        Code block mode stream processor, ensure the response is a code block with output markdown quote

        :param model: model name
        :param prompt_messages: prompt messages
        :param input_generator: input generator
        :return: output generator
        """
        state = "normal"
        backtick_count = 0
        for piece in input_generator:
            if piece.delta.message.content:
                content = piece.delta.message.content
                piece.delta.message.content = ""
                yield piece
                piece = content
            else:
                yield piece
                continue
            new_piece = ""
            for char in piece:
                if state == "normal":
                    if char == "`":
                        state = "in_backticks"
                        backtick_count = 1
                    else:
                        new_piece += char
                elif state == "in_backticks":
                    if char == "`":
                        backtick_count += 1
                        if backtick_count == 3:
                            state = "skip_content"
                            backtick_count = 0
                    else:
                        new_piece += "`" * backtick_count + char
                        state = "normal"
                        backtick_count = 0
                elif state == "skip_content":
                    if char.isspace():
                        state = "normal"

            if new_piece:
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=new_piece,
                            tool_calls=[]
                        ),
                    )
                )

    def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
                                                        input_generator: Generator[LLMResultChunk, None, None]) \
            -> Generator[LLMResultChunk, None, None]:
        """
        Code block mode stream processor, ensure the response is a code block with output markdown quote.
        This version skips the language identifier that follows the opening triple backticks.

        :param model: model name
        :param prompt_messages: prompt messages
        :param input_generator: input generator
        :return: output generator
        """
        state = "search_start"
        backtick_count = 0

        for piece in input_generator:
            if piece.delta.message.content:
                content = piece.delta.message.content
                # Reset content to ensure we're only processing and yielding the relevant parts
                piece.delta.message.content = ""
                # Yield a piece with cleared content before processing it to maintain the generator structure
                yield piece
                piece = content
            else:
                # Yield pieces without content directly
                yield piece
                continue

            if state == "done":
                continue

            new_piece = ""
            for char in piece:
                if state == "search_start":
                    if char == "`":
                        backtick_count += 1
                        if backtick_count == 3:
                            state = "skip_language"
                            backtick_count = 0
                    else:
                        backtick_count = 0
                elif state == "skip_language":
                    # Skip everything until the first newline, marking the end of the language identifier
                    if char == "\n":
                        state = "in_code_block"
                elif state == "in_code_block":
                    if char == "`":
                        backtick_count += 1
                        if backtick_count == 3:
                            state = "done"
                            break
                    else:
                        if backtick_count > 0:
                            # If backticks were counted but we're still collecting content, it was a false start
                            new_piece += "`" * backtick_count
                            backtick_count = 0
                        new_piece += char

                elif state == "done":
                    break

            if new_piece:
                # Only yield content collected within the code block
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=0,
                        message=AssistantPromptMessage(
                            content=new_piece,
                            tool_calls=[]
                        ),
                    )
                )

    def _invoke_result_generator(self, model: str, result: Generator, credentials: dict,
                                 prompt_messages: list[PromptMessage], model_parameters: dict,
                                 tools: Optional[list[PromptMessageTool]] = None,
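The two stream processors above strip markdown fences from the streamed response character by character. A hedged non-streaming reduction of the backtick variant, run over a whole string instead of chunks, shows the same state machine at work (a minimal sketch for illustration only):

```python
# Non-streaming reduction of _code_block_mode_stream_processor_with_backtick
# above: extract only the fenced payload, skipping the language identifier.
def extract_code_block(text: str) -> str:
    state, backticks, out = "search_start", 0, []
    for char in text:
        if state == "search_start":
            if char == "`":
                backticks += 1
                if backticks == 3:
                    state, backticks = "skip_language", 0
            else:
                backticks = 0
        elif state == "skip_language":
            # skip until the newline that ends the language identifier
            if char == "\n":
                state = "in_code_block"
        elif state == "in_code_block":
            if char == "`":
                backticks += 1
                if backticks == 3:
                    break
            else:
                if backticks:
                    # counted backticks that were a false start, keep them
                    out.append("`" * backticks)
                    backticks = 0
                out.append(char)
    return "".join(out)

assert extract_code_block('```JSON\n{"answer": "42"}\n```') == '{"answer": "42"}\n'
```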
@@ -204,7 +456,7 @@ class LargeLanguageModel(AIModel):
        :return: full response or stream response chunk generator result
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
@@ -6,6 +6,7 @@
- bedrock
- togetherai
- ollama
- mistralai
- replicate
- huggingface_hub
- zhipuai
@@ -27,6 +27,8 @@ parameter_rules:
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '8.00'
  output: '24.00'
@@ -27,6 +27,8 @@ parameter_rules:
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '8.00'
  output: '24.00'
@@ -26,6 +26,8 @@ parameter_rules:
    default: 4096
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '1.63'
  output: '5.51'
@@ -6,6 +6,7 @@ from anthropic import Anthropic, Stream
from anthropic.types import Completion, completion_create_params
from httpx import Timeout

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
@@ -25,9 +26,16 @@ from core.model_runtime.errors.invoke import (
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can find in the instructions; use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

class AnthropicLargeLanguageModel(LargeLanguageModel):

    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
@@ -48,6 +56,53 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        if 'response_format' in model_parameters and model_parameters['response_format']:
            stop = stop or []
            self._transform_json_prompts(
                model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, model_parameters['response_format']
            )
            model_parameters.pop('response_format')

        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _transform_json_prompts(self, model: str, credentials: dict,
                                prompt_messages: list[PromptMessage], model_parameters: dict,
                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=ANTHROPIC_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", prompt_messages[0].content)
                .replace("{{block}}", response_format)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=ANTHROPIC_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                .replace("{{block}}", response_format)
            ))

        prompt_messages.append(AssistantPromptMessage(
            content=f"```{response_format}\n"
        ))

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
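`_transform_json_prompts` above steers the model by rewriting the system message and appending an assistant message that already opens the code fence, so the completion continues inside the block and halts at the closing-fence stop sequence. A hedged sketch of the resulting prompt shape, using plain (role, content) tuples instead of Dify's message classes:

```python
# Hedged sketch of the prompt shape produced by _transform_json_prompts above;
# plain (role, content) tuples stand in for Dify's message classes.
def transform_json_prompts(messages: list[tuple[str, str]],
                           stop: list[str],
                           response_format: str = "JSON") -> list[tuple[str, str]]:
    if "```\n" not in stop:
        stop.append("```\n")
    instructions = f"Please output a valid {response_format} object."
    system = ("system", f"You should always follow the instructions and output "
                        f"a valid {response_format} object.\n<instructions>\n"
                        f"{instructions}\n</instructions>")
    # The trailing assistant message pre-opens the fence, so the model's
    # completion starts inside the code block and ends at the stop sequence.
    return [system, *messages, ("assistant", f"```{response_format}\n")]

stop: list[str] = []
out = transform_json_prompts([("user", "Give me the answer to 6 * 7.")], stop)
assert out[-1] == ("assistant", "```JSON\n") and "```\n" in stop
```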
@@ -27,6 +27,8 @@ parameter_rules:
    default: 2048
    min: 1
    max: 2048
  - name: response_format
    use_template: response_format
pricing:
  input: '0.00'
  output: '0.00'
|
@ -31,6 +31,16 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GEMINI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
|
||||
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
|
||||
if you are not sure about the structure.
|
||||
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
"""
|
||||
|
||||
|
||||
class GoogleLargeLanguageModel(LargeLanguageModel):
|
||||
|
||||
def _invoke(self, model: str, credentials: dict,
|
||||
|
|
@@ -53,7 +63,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
- open-mistral-7b
- open-mixtral-8x7b
- mistral-small-latest
- mistral-medium-latest
- mistral-large-latest
@@ -0,0 +1,31 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


class MistralAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:

        self._add_custom_parameters(credentials)

        # mistral does not support user/stop arguments
        stop = []
        user = None

        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        self._add_custom_parameters(credentials)
        super().validate_credentials(model, credentials)

    @staticmethod
    def _add_custom_parameters(credentials: dict) -> None:
        credentials['mode'] = 'chat'
        credentials['endpoint_url'] = 'https://api.mistral.ai/v1'
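The new MistralAI provider reuses the OpenAI-compatible base by pinning `mode` and `endpoint_url` into the credentials before delegating. A hedged sketch of that adapter pattern with illustrative stand-in classes (not Dify's actual base class):

```python
# The adapter pattern used above, reduced to a sketch: a provider subclass
# pins endpoint details into the credentials dict, then defers to a generic
# OpenAI-compatible base. Class names here are illustrative stand-ins.
class OAICompatBase:
    def invoke(self, model: str, credentials: dict, prompt: str) -> str:
        return f"POST {credentials['endpoint_url']}/chat/completions ({model})"

class MistralAdapter(OAICompatBase):
    def invoke(self, model: str, credentials: dict, prompt: str) -> str:
        credentials.setdefault('mode', 'chat')
        credentials.setdefault('endpoint_url', 'https://api.mistral.ai/v1')
        return super().invoke(model, credentials, prompt)

creds: dict = {}
print(MistralAdapter().invoke('mistral-small-latest', creds, 'hi'))
# -> POST https://api.mistral.ai/v1/chat/completions (mistral-small-latest)
```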
@@ -0,0 +1,50 @@
model: mistral-large-latest
label:
  zh_Hans: mistral-large-latest
  en_US: mistral-large-latest
model_type: llm
features:
  - agent-thought
model_properties:
  context_size: 32000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.7
    min: 0
    max: 1
  - name: top_p
    use_template: top_p
    default: 1
    min: 0
    max: 1
  - name: max_tokens
    use_template: max_tokens
    default: 1024
    min: 1
    max: 8000
  - name: safe_prompt
    default: false
    type: boolean
    help:
      en_US: Whether to inject a safety prompt before all conversations.
      zh_Hans: 是否开启提示词审查
    label:
      en_US: SafePrompt
      zh_Hans: 提示词审查
  - name: random_seed
    type: int
    help:
      en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
      zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
    label:
      en_US: RandomSeed
      zh_Hans: 随机数种子
    default: 0
    min: 0
    max: 2147483647
pricing:
  input: '0.008'
  output: '0.024'
  unit: '0.001'
  currency: USD
@ -0,0 +1,50 @@
|
|||
model: mistral-medium-latest
|
||||
label:
|
||||
zh_Hans: mistral-medium-latest
|
||||
en_US: mistral-medium-latest
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.0027'
|
||||
output: '0.0081'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
model: mistral-small-latest
|
||||
label:
|
||||
zh_Hans: mistral-small-latest
|
||||
en_US: mistral-small-latest
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.002'
|
||||
output: '0.006'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
model: open-mistral-7b
|
||||
label:
|
||||
zh_Hans: open-mistral-7b
|
||||
en_US: open-mistral-7b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
context_size: 8000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 2048
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.00025'
|
||||
output: '0.00025'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
model: open-mixtral-8x7b
|
||||
label:
|
||||
zh_Hans: open-mixtral-8x7b
|
||||
en_US: open-mixtral-8x7b
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
model_properties:
|
||||
context_size: 32000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
default: 0.7
|
||||
min: 0
|
||||
max: 1
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
default: 1
|
||||
min: 0
|
||||
max: 1
|
||||
- name: max_tokens
|
||||
use_template: max_tokens
|
||||
default: 1024
|
||||
min: 1
|
||||
max: 8000
|
||||
- name: safe_prompt
|
||||
defulat: false
|
||||
type: boolean
|
||||
help:
|
||||
en_US: Whether to inject a safety prompt before all conversations.
|
||||
zh_Hans: 是否开启提示词审查
|
||||
label:
|
||||
en_US: SafePrompt
|
||||
zh_Hans: 提示词审查
|
||||
- name: random_seed
|
||||
type: int
|
||||
help:
|
||||
en_US: The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
zh_Hans: 当开启随机数种子以后,你可以通过指定一个固定的种子来使得回答结果更加稳定
|
||||
label:
|
||||
en_US: RandomSeed
|
||||
zh_Hans: 随机数种子
|
||||
default: 0
|
||||
min: 0
|
||||
max: 2147483647
|
||||
pricing:
|
||||
input: '0.0007'
|
||||
output: '0.0007'
|
||||
unit: '0.001'
|
||||
currency: USD
|
||||
|
|
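The five model cards above share one schema; only context_size, the max_tokens ceiling, and pricing differ. A quick way to sanity-check a card is to load it and collect the declared defaults (a sketch; the local path is hypothetical):

import yaml

# Hypothetical path to one of the cards added above.
path = 'api/core/model_runtime/model_providers/mistralai/llm/mistral-large-latest.yaml'
with open(path) as f:
    card = yaml.safe_load(f)

defaults = {rule['name']: rule['default']
            for rule in card['parameter_rules'] if 'default' in rule}
print(card['model'], card['model_properties']['context_size'], defaults)
# mistral-large-latest 32000 {'temperature': 0.7, 'top_p': 1, 'max_tokens': 1024, ...}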
@@ -0,0 +1,30 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class MistralAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials.
        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            model_instance.validate_credentials(
                model='open-mistral-7b',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
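Provider-level validation above simply probes the cheapest model. A sketch of the failure path, assuming the provider class can be constructed directly (Dify's factory may wire it differently):

from core.model_runtime.errors.validate import CredentialsValidateFailedError

provider = MistralAIProvider()
try:
    provider.validate_provider_credentials(credentials={'api_key': 'obviously-bad-key'})
except CredentialsValidateFailedError:
    print('rejected: the open-mistral-7b probe call failed')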
@@ -0,0 +1,31 @@
provider: mistralai
label:
  en_US: MistralAI
description:
  en_US: Models provided by MistralAI, such as open-mistral-7b and mistral-large-latest.
  zh_Hans: MistralAI 提供的模型,例如 open-mistral-7b 和 mistral-large-latest。
icon_small:
  en_US: icon_s_en.png
icon_large:
  en_US: icon_l_en.png
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from MistralAI
    zh_Hans: 从 MistralAI 获取 API Key
  url:
    en_US: https://console.mistral.ai/api-keys/
supported_model_types:
  - llm
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
@@ -13,6 +13,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        self._add_custom_parameters(credentials)
        user = user[:32] if user else None
        return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def validate_credentials(self, model: str, credentials: dict) -> None:
@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.0005'
  output: '0.0015'

@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '0.0015'
  output: '0.002'

@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.001'
  output: '0.002'

@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 16385
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.004'

@@ -24,6 +24,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 16385
  - name: response_format
    use_template: response_format
pricing:
  input: '0.003'
  output: '0.004'

@@ -21,6 +21,8 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    use_template: response_format
pricing:
  input: '0.0015'
  output: '0.002'

@@ -24,6 +24,18 @@ parameter_rules:
    default: 512
    min: 1
    max: 4096
  - name: response_format
    label:
      zh_Hans: 回复格式
      en_US: response_format
    type: string
    help:
      zh_Hans: 指定模型必须输出的格式
      en_US: specifying the format that the model must output
    required: false
    options:
      - text
      - json_object
pricing:
  input: '0.001'
  output: '0.002'
@@ -9,6 +9,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletio
from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import FunctionCall

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,

@@ -28,6 +29,14 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI

logger = logging.getLogger(__name__)

OPENAI_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""


class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
    """

@@ -84,6 +93,131 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
            user=user
        )

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        # handle fine tune remote models
        base_model = model
        if model.startswith('ft:'):
            base_model = model.split(':')[1]

        # get model mode
        model_mode = self.get_model_mode(base_model, credentials)

        # transform response format
        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
            stop = stop or []
            if model_mode == LLMMode.CHAT:
                # chat model
                self._transform_chat_json_prompts(
                    model=base_model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    response_format=model_parameters['response_format']
                )
            else:
                self._transform_completion_json_prompts(
                    model=base_model,
                    credentials=credentials,
                    prompt_messages=prompt_messages,
                    model_parameters=model_parameters,
                    tools=tools,
                    stop=stop,
                    stream=stream,
                    user=user,
                    response_format=model_parameters['response_format']
                )
            model_parameters.pop('response_format')

        return self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

    def _transform_chat_json_prompts(self, model: str, credentials: dict,
                                     prompt_messages: list[PromptMessage], model_parameters: dict,
                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                     stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=OPENAI_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", prompt_messages[0].content)
                .replace("{{block}}", response_format)
            )
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}\n"))
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=OPENAI_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                .replace("{{block}}", response_format)
            ))
            prompt_messages.append(AssistantPromptMessage(content=f"\n```{response_format}"))

    def _transform_completion_json_prompts(self, model: str, credentials: dict,
                                           prompt_messages: list[PromptMessage], model_parameters: dict,
                                           tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                           stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts
        """
        if "```\n" not in stop:
            stop.append("```\n")
        if "\n```" not in stop:
            stop.append("\n```")

        # override the last user message
        user_message = None
        for i in range(len(prompt_messages) - 1, -1, -1):
            if isinstance(prompt_messages[i], UserPromptMessage):
                user_message = prompt_messages[i]
                break

        if user_message:
            if prompt_messages[i].content[-11:] == 'Assistant: ':
                # now we are in the chat app, remove the last assistant message
                prompt_messages[i].content = prompt_messages[i].content[:-11]
                prompt_messages[i] = UserPromptMessage(
                    content=OPENAI_BLOCK_MODE_PROMPT
                    .replace("{{instructions}}", user_message.content)
                    .replace("{{block}}", response_format)
                )
                prompt_messages[i].content += f"Assistant:\n```{response_format}\n"
            else:
                prompt_messages[i] = UserPromptMessage(
                    content=OPENAI_BLOCK_MODE_PROMPT
                    .replace("{{instructions}}", user_message.content)
                    .replace("{{block}}", response_format)
                )
                prompt_messages[i].content += f"\n```{response_format}\n"

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
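To make the chat branch concrete, here is roughly what `_transform_chat_json_prompts` does to a message list when `response_format='JSON'` (an illustrative sketch built from the code above, not additional diff content):

from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

messages = [
    SystemPromptMessage(content='Reply with {"answer": ...}.'),
    UserPromptMessage(content='What is 2 + 2?'),
]
stop = []

# After _transform_chat_json_prompts(..., prompt_messages=messages, stop=stop, response_format='JSON'):
# - messages[0].content is OPENAI_BLOCK_MODE_PROMPT with the old system text spliced
#   into <instructions> and {{block}} replaced by 'JSON'
# - an AssistantPromptMessage('\n```JSON\n') is appended, priming the model inside a fence
# - stop has gained '```\n' and '\n```', so generation halts at the closing fence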
@@ -13,6 +13,7 @@ from dashscope.common.error import (
)
from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,

@@ -57,6 +58,88 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)

    def _code_block_mode_wrapper(self, model: str, credentials: dict,
                                 prompt_messages: list[PromptMessage], model_parameters: dict,
                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                 stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
            -> LLMResult | Generator:
        """
        Wrapper for code block mode
        """
        block_prompts = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

<instructions>
{{instructions}}
</instructions>
"""

        code_block = model_parameters.get("response_format", "")
        if not code_block:
            return self._invoke(
                model=model,
                credentials=credentials,
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=tools,
                stop=stop,
                stream=stream,
                user=user
            )

        model_parameters.pop("response_format")
        stop = stop or []
        stop.extend(["\n```", "```\n"])
        block_prompts = block_prompts.replace("{{block}}", code_block)

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", prompt_messages[0].content)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=block_prompts
                .replace("{{instructions}}", f"Please output a valid {code_block} object.")
            ))

        mode = self.get_model_mode(model, credentials)
        if mode == LLMMode.CHAT:
            if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
                # add ```JSON\n to the last message
                prompt_messages[-1].content += f"\n```{code_block}\n"
            else:
                # append a user message
                prompt_messages.append(UserPromptMessage(
                    content=f"```{code_block}\n"
                ))
        else:
            prompt_messages.append(AssistantPromptMessage(content=f"```{code_block}\n"))

        response = self._invoke(
            model=model,
            credentials=credentials,
            prompt_messages=prompt_messages,
            model_parameters=model_parameters,
            tools=tools,
            stop=stop,
            stream=stream,
            user=user
        )

        if isinstance(response, Generator):
            return self._code_block_mode_stream_processor_with_backtick(
                model=model,
                prompt_messages=prompt_messages,
                input_generator=response
            )

        return response

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:

@@ -117,7 +200,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        """
        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop_sequences'] = stop
            extra_model_kwargs['stop'] = stop

        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)

@@ -131,7 +214,8 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
        params = {
            'model': model,
            **model_parameters,
            **credentials_kwargs
            **credentials_kwargs,
            **extra_model_kwargs,
        }

        mode = self.get_model_mode(model, credentials)
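The net effect of the last two Tongyi hunks is that stop words finally reach the dashscope call; in isolation (self-contained sketch, placeholder values):

stop = ['\n```', '```\n']
model = 'qwen-max'
model_parameters = {'temperature': 0.7}
credentials_kwargs = {'api_key': 'sk-placeholder'}

extra_model_kwargs = {}
if stop:
    # dashscope expects 'stop'; the old 'stop_sequences' key was never read
    extra_model_kwargs['stop'] = stop

params = {
    'model': model,
    **model_parameters,
    **credentials_kwargs,
    **extra_model_kwargs,  # previously never merged into params at all
}
print(params)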
@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
    required: false
  - name: response_format
    use_template: response_format

@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
    required: false
  - name: response_format
    use_template: response_format

@@ -57,3 +57,5 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
    required: false
  - name: response_format
    use_template: response_format

@@ -56,6 +56,8 @@ parameter_rules:
    help:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
  - name: response_format
    use_template: response_format
pricing:
  input: '0.02'
  output: '0.02'

@@ -57,6 +57,8 @@ parameter_rules:
      zh_Hans: 用于控制模型生成时的重复度。提高repetition_penalty时可以降低模型生成的重复度。1.0表示不做惩罚。
      en_US: Used to control the repetition of model generation. Increasing the repetition_penalty can reduce the repetition of model generation. 1.0 means no punishment.
    required: false
  - name: response_format
    use_template: response_format
pricing:
  input: '0.008'
  output: '0.008'

@@ -25,6 +25,8 @@ parameter_rules:
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索

@@ -25,6 +25,8 @@ parameter_rules:
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format
  - name: disable_search
    label:
      zh_Hans: 禁用搜索

@@ -25,3 +25,5 @@ parameter_rules:
    use_template: presence_penalty
  - name: frequency_penalty
    use_template: frequency_penalty
  - name: response_format
    use_template: response_format

@@ -34,3 +34,5 @@ parameter_rules:
      zh_Hans: 禁用模型自行进行外部搜索。
      en_US: Disable the model to perform external search.
    required: false
  - name: response_format
    use_template: response_format
@@ -1,6 +1,7 @@
from collections.abc import Generator
from typing import cast
from typing import Optional, Union, cast

from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,

@@ -29,8 +30,18 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
    RateLimitReachedError,
)

ERNIE_BOT_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

class ErnieBotLarguageModel(LargeLanguageModel):
<instructions>
{{instructions}}
</instructions>

You should also complete the text started with ``` but not tell ``` directly.
"""

class ErnieBotLargeLanguageModel(LargeLanguageModel):
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,

@@ -39,6 +50,62 @@ class ErnieBotLarguageModel(LargeLanguageModel):
        return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
                              model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)

    def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                                 model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                                 stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None,
                                 callbacks: list[Callback] = None) -> Union[LLMResult, Generator]:
        """
        Code block mode wrapper for invoking large language model
        """
        if 'response_format' in model_parameters and model_parameters['response_format'] in ['JSON', 'XML']:
            response_format = model_parameters['response_format']
            stop = stop or []
            self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user, response_format)
            model_parameters.pop('response_format')
            if stream:
                return self._code_block_mode_stream_processor(
                    model=model,
                    prompt_messages=prompt_messages,
                    input_generator=self._invoke(model=model, credentials=credentials, prompt_messages=prompt_messages,
                                                 model_parameters=model_parameters, tools=tools, stop=stop, stream=stream, user=user)
                )

        return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def _transform_json_prompts(self, model: str, credentials: dict,
                                prompt_messages: list[PromptMessage], model_parameters: dict,
                                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
            -> None:
        """
        Transform json prompts to model prompts
        """

        # check if there is a system message
        if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
            # override the system message
            prompt_messages[0] = SystemPromptMessage(
                content=ERNIE_BOT_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", prompt_messages[0].content)
                .replace("{{block}}", response_format)
            )
        else:
            # insert the system message
            prompt_messages.insert(0, SystemPromptMessage(
                content=ERNIE_BOT_BLOCK_MODE_PROMPT
                .replace("{{instructions}}", f"Please output a valid {response_format} object.")
                .replace("{{block}}", response_format)
            ))

        if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
            # add ```JSON\n to the last message
            prompt_messages[-1].content += "\n```JSON\n{\n"
        else:
            # append a user message
            prompt_messages.append(UserPromptMessage(
                content="```JSON\n{\n"
            ))

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: list[PromptMessageTool] | None = None) -> int:
        # tools is not supported yet
@@ -19,6 +19,17 @@ from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_comp
from core.model_runtime.model_providers.zhipuai.zhipuai_sdk.types.chat.chat_completion_chunk import ChatCompletionChunk
from core.model_runtime.utils import helper

GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
if you are not sure about the structure.

And you should always end the block with a "```" to indicate the end of the JSON object.

<instructions>
{{instructions}}
</instructions>

```JSON"""

class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

@@ -44,8 +55,42 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # invoke model
        # stop = stop or []
        # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
        return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user)

    # def _transform_json_prompts(self, model: str, credentials: dict,
    #                             prompt_messages: list[PromptMessage], model_parameters: dict,
    #                             tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
    #                             stream: bool = True, user: str | None = None) \
    #         -> None:
    #     """
    #     Transform json prompts to model prompts
    #     """
    #     if "}\n\n" not in stop:
    #         stop.append("}\n\n")

    #     # check if there is a system message
    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage):
    #         # override the system message
    #         prompt_messages[0] = SystemPromptMessage(
    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content)
    #         )
    #     else:
    #         # insert the system message
    #         prompt_messages.insert(0, SystemPromptMessage(
    #             content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.")
    #         ))
    #     # check if the last message is a user message
    #     if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage):
    #         # add ```JSON\n to the last message
    #         prompt_messages[-1].content += "\n```JSON\n"
    #     else:
    #         # append a user message
    #         prompt_messages.append(UserPromptMessage(
    #             content="```JSON\n"
    #         ))

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """

@@ -106,7 +151,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        """
        extra_model_kwargs = {}
        if stop:
            extra_model_kwargs['stop_sequences'] = stop
            extra_model_kwargs['stop'] = stop

        client = ZhipuAI(
            api_key=credentials_kwargs['api_key']

@@ -256,10 +301,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
        ]

        if stream:
            response = client.chat.completions.create(stream=stream, **params)
            response = client.chat.completions.create(stream=stream, **params, **extra_model_kwargs)
            return self._handle_generate_stream_response(model, credentials_kwargs, tools, response, prompt_messages)

        response = client.chat.completions.create(**params)
        response = client.chat.completions.create(**params, **extra_model_kwargs)
        return self._handle_generate_response(model, credentials_kwargs, tools, response, prompt_messages)

    def _handle_generate_response(self, model: str,
@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any

from flask import current_app

@@ -14,7 +14,7 @@ class Keyword:
        self._keyword_processor = self._init_keyword()

    def _init_keyword(self) -> BaseKeyword:
        config = cast(dict, current_app.config)
        config = current_app.config
        keyword_type = config.get('KEYWORD_STORE')

        if not keyword_type:

@@ -25,7 +25,7 @@ class Keyword:
                dataset=self._dataset
            )
        else:
            raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")
            raise ValueError(f"Keyword store {keyword_type} is not supported.")

    def create(self, texts: list[Document], **kwargs):
        self._keyword_processor.create(texts, **kwargs)
@@ -2,7 +2,6 @@ import threading
from typing import Optional

from flask import Flask, current_app
from flask_login import current_user

from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword

@@ -27,6 +26,11 @@ class RetrievalService:
    @classmethod
    def retrieve(cls, retrival_method: str, dataset_id: str, query: str,
                 top_k: int, score_threshold: Optional[float] = .0, reranking_model: Optional[dict] = None):
        dataset = db.session.query(Dataset).filter(
            Dataset.id == dataset_id
        ).first()
        if not dataset or dataset.available_document_count == 0 or dataset.available_segment_count == 0:
            return []
        all_documents = []
        threads = []
        # retrieval_model source with keyword

@@ -35,7 +39,8 @@ class RetrievalService:
                'flask_app': current_app._get_current_object(),
                'dataset_id': dataset_id,
                'query': query,
                'top_k': top_k
                'top_k': top_k,
                'all_documents': all_documents
            })
            threads.append(keyword_thread)
            keyword_thread.start()

@@ -73,7 +78,7 @@ class RetrievalService:
            thread.join()

        if retrival_method == 'hybrid_search':
            data_post_processor = DataPostProcessor(str(current_user.current_tenant_id), reranking_model, False)
            data_post_processor = DataPostProcessor(str(dataset.tenant_id), reranking_model, False)
            all_documents = data_post_processor.invoke(
                query=query,
                documents=all_documents,

@@ -96,7 +101,7 @@ class RetrievalService:

            documents = keyword.search(
                query,
                k=top_k
                top_k=top_k
            )
            all_documents.extend(documents)

@@ -116,7 +121,7 @@ class RetrievalService:
            documents = vector.search_by_vector(
                query,
                search_type='similarity_score_threshold',
                k=top_k,
                top_k=top_k,
                score_threshold=score_threshold,
                filter={
                    'group_id': [dataset.id]
@@ -7,4 +7,4 @@ class Field(Enum):
    GROUP_KEY = "group_id"
    VECTOR = "vector"
    TEXT_KEY = "text"
    PRIMARY_KEY = " id"
    PRIMARY_KEY = "id"
@@ -124,12 +124,23 @@ class MilvusVector(BaseVector):

    def delete_by_ids(self, doc_ids: list[str]) -> None:

        self._client.delete(collection_name=self._collection_name, pks=doc_ids)
        result = self._client.query(collection_name=self._collection_name,
                                    filter=f'metadata["doc_id"] in {doc_ids}',
                                    output_fields=["id"])
        if result:
            ids = [item["id"] for item in result]
            self._client.delete(collection_name=self._collection_name, pks=ids)

    def delete(self) -> None:
        alias = uuid4().hex
        if self._client_config.secure:
            uri = "https://" + str(self._client_config.host) + ":" + str(self._client_config.port)
        else:
            uri = "http://" + str(self._client_config.host) + ":" + str(self._client_config.port)
        connections.connect(alias=alias, uri=uri, user=self._client_config.user, password=self._client_config.password)

        from pymilvus import utility
        utility.drop_collection(self._collection_name, None)
        utility.drop_collection(self._collection_name, None, using=alias)

    def text_exists(self, id: str) -> bool:
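A note on the Milvus hunk above: the incoming `doc_ids` are Dify document IDs kept in metadata, not Milvus primary keys, so the old `delete(pks=doc_ids)` matched nothing; the fix resolves primary keys first. The same two-step pattern in isolation (assuming a pymilvus `MilvusClient`, whose call shapes match the hunk; the collection name is illustrative):

from pymilvus import MilvusClient  # assumed client API, matching the calls in the hunk

client = MilvusClient(uri='http://localhost:19530')
doc_ids = ['doc-1', 'doc-2']

# Step 1: translate metadata doc_ids into primary keys.
rows = client.query(collection_name='Vector_index_example_Node',
                    filter=f'metadata["doc_id"] in {doc_ids}',
                    output_fields=['id'])

# Step 2: delete by primary key only if anything matched.
if rows:
    client.delete(collection_name='Vector_index_example_Node',
                  pks=[row['id'] for row in rows])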
@@ -1,4 +1,5 @@
from typing import Any, cast
import json
from typing import Any

from flask import current_app

@@ -22,7 +23,7 @@ class Vector:
        self._vector_processor = self._init_vector()

    def _init_vector(self) -> BaseVector:
        config = cast(dict, current_app.config)
        config = current_app.config
        vector_type = config.get('VECTOR_STORE')

        if self._dataset.index_struct_dict:

@@ -39,6 +40,11 @@ class Vector:
        else:
            dataset_id = self._dataset.id
            collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
            index_struct_dict = {
                "type": 'weaviate',
                "vector_store": {"class_prefix": collection_name}
            }
            self._dataset.index_struct = json.dumps(index_struct_dict)
        return WeaviateVector(
            collection_name=collection_name,
            config=WeaviateConfig(

@@ -66,6 +72,13 @@ class Vector:
            dataset_id = self._dataset.id
            collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'

            if not self._dataset.index_struct_dict:
                index_struct_dict = {
                    "type": 'qdrant',
                    "vector_store": {"class_prefix": collection_name}
                }
                self._dataset.index_struct = json.dumps(index_struct_dict)

            return QdrantVector(
                collection_name=collection_name,
                group_id=self._dataset.id,

@@ -84,6 +97,11 @@ class Vector:
        else:
            dataset_id = self._dataset.id
            collection_name = "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
            index_struct_dict = {
                "type": 'milvus',
                "vector_store": {"class_prefix": collection_name}
            }
            self._dataset.index_struct = json.dumps(index_struct_dict)
        return MilvusVector(
            collection_name=collection_name,
            config=MilvusConfig(
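All three factory hunks above persist the same bookkeeping record; the JSON written to `dataset.index_struct` has this shape (reconstructed from the hunks; the dataset id is illustrative):

import json

dataset_id = '0b3c2f66-1234-5678-9abc-def012345678'.replace('-', '_')  # illustrative
index_struct = {
    'type': 'qdrant',  # or 'weaviate' / 'milvus', matching the active VECTOR_STORE
    'vector_store': {'class_prefix': f'Vector_index_{dataset_id}_Node'},
}
print(json.dumps(index_struct))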
@@ -127,7 +127,10 @@ class WeaviateVector(BaseVector):
        )

    def delete(self):
        self._client.schema.delete_class(self._collection_name)
        # check whether the index already exists
        schema = self._default_schema(self._collection_name)
        if self._client.schema.contains(schema):
            self._client.schema.delete_class(self._collection_name)

    def text_exists(self, id: str) -> bool:
        collection_name = self._collection_name

@@ -147,10 +150,11 @@ class WeaviateVector(BaseVector):
        return True

    def delete_by_ids(self, ids: list[str]) -> None:
        self._client.data_object.delete(
            ids,
            class_name=self._collection_name
        )
        for uuid in ids:
            self._client.data_object.delete(
                class_name=self._collection_name,
                uuid=uuid,
            )

    def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
        """Look up similar documents by embedding vector in Weaviate."""
@@ -12,6 +12,7 @@ class NotionInfo(BaseModel):
    notion_obj_id: str
    notion_page_type: str
    document: Document = None
    tenant_id: str

    class Config:
        arbitrary_types_allowed = True

@@ -132,7 +132,8 @@ class ExtractProcessor:
                    notion_workspace_id=extract_setting.notion_info.notion_workspace_id,
                    notion_obj_id=extract_setting.notion_info.notion_obj_id,
                    notion_page_type=extract_setting.notion_info.notion_page_type,
                    document_model=extract_setting.notion_info.document
                    document_model=extract_setting.notion_info.document,
                    tenant_id=extract_setting.notion_info.tenant_id,
                )
                return extractor.extract()
            else:
@@ -1,13 +1,14 @@
"""Abstract interface for document loader implementations."""
from typing import Optional
from bs4 import BeautifulSoup

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.extractor.helpers import detect_file_encodings
from core.rag.models.document import Document


class HtmlExtractor(BaseExtractor):
    """Load html files.

    """
    Load html files.


    Args:

@@ -15,57 +16,19 @@ class HtmlExtractor(BaseExtractor):
    """

    def __init__(
        self,
        file_path: str,
        encoding: Optional[str] = None,
        autodetect_encoding: bool = False,
        source_column: Optional[str] = None,
        csv_args: Optional[dict] = None,
        self,
        file_path: str
    ):
        """Initialize with file path."""
        self._file_path = file_path
        self._encoding = encoding
        self._autodetect_encoding = autodetect_encoding
        self.source_column = source_column
        self.csv_args = csv_args or {}

    def extract(self) -> list[Document]:
        """Load data into document objects."""
        try:
            with open(self._file_path, newline="", encoding=self._encoding) as csvfile:
                docs = self._read_from_file(csvfile)
        except UnicodeDecodeError as e:
            if self._autodetect_encoding:
                detected_encodings = detect_file_encodings(self._file_path)
                for encoding in detected_encodings:
                    try:
                        with open(self._file_path, newline="", encoding=encoding.encoding) as csvfile:
                            docs = self._read_from_file(csvfile)
                        break
                    except UnicodeDecodeError:
                        continue
            else:
                raise RuntimeError(f"Error loading {self._file_path}") from e
        return [Document(page_content=self._load_as_text())]

        return docs
    def _load_as_text(self) -> str:
        with open(self._file_path, "rb") as fp:
            soup = BeautifulSoup(fp, 'html.parser')
            text = soup.get_text()
            text = text.strip() if text else ''

    def _read_from_file(self, csvfile) -> list[Document]:
        docs = []
        csv_reader = csv.DictReader(csvfile, **self.csv_args)  # type: ignore
        for i, row in enumerate(csv_reader):
            content = "\n".join(f"{k.strip()}: {v.strip()}" for k, v in row.items())
            try:
                source = (
                    row[self.source_column]
                    if self.source_column is not None
                    else ''
                )
            except KeyError:
                raise ValueError(
                    f"Source column '{self.source_column}' not found in CSV file."
                )
            metadata = {"source": source, "row": i}
            doc = Document(page_content=content, metadata=metadata)
            docs.append(doc)

        return docs
        return text
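With the CSV machinery gone, using the rewritten extractor is a three-liner (a sketch; the module path and input file are assumptions, and `beautifulsoup4` must be installed):

from core.rag.extractor.html_extractor import HtmlExtractor  # assumed module path

extractor = HtmlExtractor('/tmp/page.html')  # hypothetical input file
docs = extractor.extract()  # a single Document holding the page's visible text
print(docs[0].page_content[:200])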
@@ -4,7 +4,6 @@ from typing import Any, Optional

import requests
from flask import current_app
from flask_login import current_user

from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document

@@ -30,8 +29,10 @@ class NotionExtractor(BaseExtractor):
            notion_workspace_id: str,
            notion_obj_id: str,
            notion_page_type: str,
            tenant_id: str,
            document_model: Optional[DocumentModel] = None,
            notion_access_token: Optional[str] = None
            notion_access_token: Optional[str] = None,
    ):
        self._notion_access_token = None
        self._document_model = document_model

@@ -41,7 +42,7 @@ class NotionExtractor(BaseExtractor):
        if notion_access_token:
            self._notion_access_token = notion_access_token
        else:
            self._notion_access_token = self._get_access_token(current_user.current_tenant_id,
            self._notion_access_token = self._get_access_token(tenant_id,
                                                               self._notion_workspace_id)
            if not self._notion_access_token:
                integration_token = current_app.config.get('NOTION_INTEGRATION_TOKEN')
@@ -1,189 +0,0 @@
import base64
import hashlib
import hmac
import json
import queue
import ssl
from datetime import datetime
from time import mktime
from typing import Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import websocket


class SparkLLMClient:
    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
        domain = 'spark-api.xf-yun.com'
        endpoint = 'chat'
        if api_domain:
            domain = api_domain
        if model_name == 'spark-v3':
            endpoint = 'multimodal'

        model_api_configs = {
            'spark': {
                'version': 'v1.1',
                'chat_domain': 'general'
            },
            'spark-v2': {
                'version': 'v2.1',
                'chat_domain': 'generalv2'
            },
            'spark-v3': {
                'version': 'v3.1',
                'chat_domain': 'generalv3'
            },
            'spark-v3.5': {
                'version': 'v3.5',
                'chat_domain': 'generalv3.5'
            }
        }

        api_version = model_api_configs[model_name]['version']

        self.chat_domain = model_api_configs[model_name]['chat_domain']
        self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
        self.app_id = app_id
        self.ws_url = self.create_url(
            urlparse(self.api_base).netloc,
            urlparse(self.api_base).path,
            self.api_base,
            api_key,
            api_secret
        )

        self.queue = queue.Queue()
        self.blocking_message = ''

    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        # generate timestamp by RFC1123
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        signature_origin = "host: " + host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + path + " HTTP/1.1"

        # encrypt using hmac-sha256
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'

        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        # generate url
        url = api_base + '?' + urlencode(v)
        return url

    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False):
        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open
        )
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})

    def on_error(self, ws, error):
        self.queue.put({
            'status_code': error.status_code,
            'error': error.resp_body.decode('utf-8')
        })
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        self.queue.put({'done': True})

    def on_open(self, ws):
        self.blocking_message = ''
        data = json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        ))
        ws.send(data)

    def on_message(self, ws, message):
        data = json.loads(message)
        code = data['header']['code']
        if code != 0:
            self.queue.put({
                'status_code': 400,
                'error': f"Code: {code}, Error: {data['header']['message']}"
            })
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            if ws.streaming:
                self.queue.put({'data': content})
            else:
                self.blocking_message += content

            if status == 2:
                if not ws.streaming:
                    self.queue.put({'data': self.blocking_message})
                ws.close()

    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        data = {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": {
                    "domain": self.chat_domain
                }
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }

        if model_kwargs:
            data['parameter']['chat'].update(model_kwargs)

        return data

    def subscribe(self):
        while True:
            content = self.queue.get()
            if 'error' in content:
                if content['status_code'] == 401:
                    raise SparkError('[Spark] The credentials you provided are incorrect. '
                                     'Please double-check and fill them in again.')
                elif content['status_code'] == 403:
                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
                                     "Please try again after obtaining the necessary permissions.")
                else:
                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")

            if 'data' not in content:
                break
            yield content


class SparkError(Exception):
    pass
@@ -1,24 +0,0 @@
from datetime import datetime

from langchain.tools import BaseTool
from pydantic import BaseModel, Field


class DatetimeToolInput(BaseModel):
    type: str = Field(..., description="Type for current time, must be: datetime.")


class DatetimeTool(BaseTool):
    """Tool for querying current datetime."""
    name: str = "current_datetime"
    args_schema: type[BaseModel] = DatetimeToolInput
    description: str = "A tool when you want to get the current date, time, week, month or year, " \
                       "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\"."

    def _run(self, type: str) -> str:
        # get current time
        current_time = datetime.utcnow()
        return current_time.strftime("%Y-%m-%d %H:%M:%S UTC+0000 %A")

    async def _arun(self, tool_input: str) -> str:
        raise NotImplementedError()
@@ -1,63 +0,0 @@
import base64
from abc import ABC, abstractmethod
from typing import Optional

from extensions.ext_database import db
from libs import rsa
from models.account import Tenant
from models.tool import ToolProvider, ToolProviderName


class BaseToolProvider(ABC):
    def __init__(self, tenant_id: str):
        self.tenant_id = tenant_id

    @abstractmethod
    def get_provider_name(self) -> ToolProviderName:
        raise NotImplementedError

    @abstractmethod
    def encrypt_credentials(self, credentials: dict) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def get_credentials(self, obfuscated: bool = False) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_to_func_kwargs(self) -> Optional[dict]:
        raise NotImplementedError

    @abstractmethod
    def credentials_validate(self, credentials: dict):
        raise NotImplementedError

    def get_provider(self, must_enabled: bool = False) -> Optional[ToolProvider]:
        """
        Returns the Provider instance for the given tenant_id and tool_name.
        """
        query = db.session.query(ToolProvider).filter(
            ToolProvider.tenant_id == self.tenant_id,
            ToolProvider.tool_name == self.get_provider_name().value
        )

        if must_enabled:
            query = query.filter(ToolProvider.is_enabled == True)

        return query.first()

    def encrypt_token(self, token) -> str:
        tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
        encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
        return base64.b64encode(encrypted_token).decode()

    def decrypt_token(self, token: str, obfuscated: bool = False) -> str:
        token = rsa.decrypt(base64.b64decode(token), self.tenant_id)

        if obfuscated:
            return self._obfuscated_token(token)

        return token

    def _obfuscated_token(self, token: str) -> str:
        return token[:6] + '*' * (len(token) - 8) + token[-2:]
@@ -1,2 +0,0 @@
class ToolValidateFailedError(Exception):
    description = "Tool Provider Validate failed"
Some files were not shown because too many files have changed in this diff.