diff --git a/.github/workflows/check_no_chinese_comments.py b/.github/workflows/check_no_chinese_comments.py
index fc01b8163a..e59cfb538b 100644
--- a/.github/workflows/check_no_chinese_comments.py
+++ b/.github/workflows/check_no_chinese_comments.py
@@ -19,7 +19,8 @@ def check_file_for_chinese_comments(file_path):
 
 def main():
     has_chinese = False
-    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py', 'indexing_runner.py', 'web_reader_tool.py']
+    excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
+                      'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py']
 
     for root, _, files in os.walk("."):
         for file in files:
diff --git a/api/.env.example b/api/.env.example
index 3ea7e5be34..946e5e9afc 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -102,3 +102,29 @@ NOTION_INTEGRATION_TYPE=public
 NOTION_CLIENT_SECRET=you-client-secret
 NOTION_CLIENT_ID=you-client-id
 NOTION_INTERNAL_SECRET=you-internal-secret
+
+# Hosted Model Credentials
+HOSTED_OPENAI_ENABLED=false
+HOSTED_OPENAI_API_KEY=
+HOSTED_OPENAI_API_BASE=
+HOSTED_OPENAI_API_ORGANIZATION=
+HOSTED_OPENAI_QUOTA_LIMIT=200
+HOSTED_OPENAI_PAID_ENABLED=false
+HOSTED_OPENAI_PAID_STRIPE_PRICE_ID=
+HOSTED_OPENAI_PAID_INCREASE_QUOTA=1
+
+HOSTED_AZURE_OPENAI_ENABLED=false
+HOSTED_AZURE_OPENAI_API_KEY=
+HOSTED_AZURE_OPENAI_API_BASE=
+HOSTED_AZURE_OPENAI_QUOTA_LIMIT=200
+
+HOSTED_ANTHROPIC_ENABLED=false
+HOSTED_ANTHROPIC_API_BASE=
+HOSTED_ANTHROPIC_API_KEY=
+HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000
+HOSTED_ANTHROPIC_PAID_ENABLED=false
+HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID=
+HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA=1
+
+STRIPE_API_KEY=
+STRIPE_WEBHOOK_SECRET=
\ No newline at end of file
diff --git a/api/app.py b/api/app.py
index dac357c0e9..4a10212f6f 100644
--- a/api/app.py
+++ b/api/app.py
@@ -16,8 +16,9 @@ from flask import Flask, request, Response, session
 import flask_login
 from flask_cors import CORS
 
+from core.model_providers.providers import hosted
 from extensions import ext_session, ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail
+    ext_database, ext_storage, ext_mail, ext_stripe
 
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
@@ -71,7 +72,7 @@ def create_app(test_config=None) -> Flask:
     register_blueprints(app)
     register_commands(app)
 
-    core.init_app(app)
+    hosted.init_app(app)
 
     return app
 
@@ -88,6 +89,7 @@ def initialize_extensions(app):
    ext_login.init_app(app)
    ext_mail.init_app(app)
    ext_sentry.init_app(app)
+    ext_stripe.init_app(app)
 
 
 def _create_tenant_for_account(account):
@@ -246,5 +248,18 @@ def threads():
     }
 
 
+@app.route('/db-pool-stat')
+def pool_stat():
+    engine = db.engine
+    return {
+        'pool_size': engine.pool.size(),
+        'checked_in_connections': engine.pool.checkedin(),
+        'checked_out_connections': engine.pool.checkedout(),
+        'overflow_connections': engine.pool.overflow(),
+        'connection_timeout': engine.pool.timeout(),
+        'recycle_time': db.engine.pool._recycle
+    }
+
+
 if __name__ == '__main__':
     app.run(host='0.0.0.0', port=5001)
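Note: the new `/db-pool-stat` endpoint simply surfaces SQLAlchemy QueuePool counters, and `pool_recycle` (added in config.py below) forces connections older than the given age to be reopened, avoiding "server closed the connection" errors when PostgreSQL or a proxy drops idle connections. A standalone sketch of the same pool calls — the SQLite DSN is illustrative only, so the snippet runs without a database server:

```python
# Illustrative only -- not part of the patch. Shows what the QueuePool
# counters returned by /db-pool-stat mean.
from sqlalchemy import create_engine
from sqlalchemy.pool import QueuePool

engine = create_engine(
    "sqlite://",
    poolclass=QueuePool,
    pool_size=30,       # mirrors SQLALCHEMY_POOL_SIZE
    pool_recycle=3600,  # mirrors SQLALCHEMY_POOL_RECYCLE: recycle after 1h
)

conn = engine.connect()
print(engine.pool.size())        # 30: configured pool size
print(engine.pool.checkedout())  # 1: connections currently in use
print(engine.pool.checkedin())   # 0: idle connections held by the pool
print(engine.pool.overflow())    # connections opened beyond pool_size
conn.close()
```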
diff --git a/api/commands.py b/api/commands.py
index a25c5fc030..caa5e1ee20 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -1,5 +1,5 @@
 import datetime
-import logging
+import math
 import random
 import string
 import time
@@ -9,18 +9,18 @@ from flask import current_app
 from werkzeug.exceptions import NotFound
 
 from core.index.index import IndexBuilder
+from core.model_providers.providers.hosted import hosted_model_providers
 from libs.password import password_pattern, valid_password, hash_password
 from libs.helper import email as email_validate
 from extensions.ext_database import db
 from libs.rsa import generate_key_pair
 from models.account import InvitationCode, Tenant
-from models.dataset import Dataset, DatasetQuery, Document, DocumentSegment
+from models.dataset import Dataset, DatasetQuery, Document
 from models.model import Account
 import secrets
 import base64
 
-from models.provider import Provider, ProviderName
-from services.provider_service import ProviderService
+from models.provider import Provider, ProviderType, ProviderQuotaType
 
 
 @click.command('reset-password', help='Reset the account password.')
@@ -251,26 +251,37 @@ def clean_unused_dataset_indexes():
 
 @click.command('sync-anthropic-hosted-providers', help='Sync anthropic hosted providers.')
 def sync_anthropic_hosted_providers():
+    if not hosted_model_providers.anthropic:
+        click.echo(click.style('Anthropic hosted provider is not configured.', fg='red'))
+        return
+
     click.echo(click.style('Start sync anthropic hosted providers.', fg='green'))
     count = 0
 
     page = 1
     while True:
         try:
-            tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc()).paginate(page=page, per_page=50)
+            providers = db.session.query(Provider).filter(
+                Provider.provider_name == 'anthropic',
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == ProviderQuotaType.TRIAL.value,
+            ).order_by(Provider.created_at.desc()).paginate(page=page, per_page=100)
        except NotFound:
            break

        page += 1
-        for tenant in tenants:
+        for provider in providers:
             try:
-                click.echo('Syncing tenant anthropic hosted provider: {}'.format(tenant.id))
-                ProviderService.create_system_provider(
-                    tenant,
-                    ProviderName.ANTHROPIC.value,
-                    current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'],
-                    True
-                )
+                click.echo('Syncing tenant anthropic hosted provider: {}'.format(provider.tenant_id))
+                original_quota_limit = provider.quota_limit
+                new_quota_limit = hosted_model_providers.anthropic.quota_limit
+                division = math.ceil(new_quota_limit / 1000)
+
+                provider.quota_limit = new_quota_limit if original_quota_limit == 1000 \
+                    else original_quota_limit * division
+                provider.quota_used = division * provider.quota_used
+                db.session.commit()
+
                 count += 1
             except Exception as e:
                 click.echo(click.style(
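Note: the migration above converts Anthropic trial quotas from a message-based unit to a token-based unit. Worked through with the shipped defaults (`HOSTED_ANTHROPIC_QUOTA_LIMIT=1000000`, old default limit 1000), and with made-up `quota_used` values for illustration:

```python
import math

new_quota_limit = 1000000                      # HOSTED_ANTHROPIC_QUOTA_LIMIT (tokens)
division = math.ceil(new_quota_limit / 1000)   # 1000: tokens per old message unit

# Tenant still on the old default limit of 1000 messages, 250 used:
original_quota_limit, quota_used = 1000, 250
quota_limit = new_quota_limit if original_quota_limit == 1000 \
    else original_quota_limit * division
print(quota_limit, division * quota_used)      # 1000000 250000

# Tenant with a custom limit of 5000 messages is scaled instead:
original_quota_limit = 5000
quota_limit = new_quota_limit if original_quota_limit == 1000 \
    else original_quota_limit * division
print(quota_limit)                             # 5000000
```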
f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}" - self.SQLALCHEMY_ENGINE_OPTIONS = {'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE'))} + self.SQLALCHEMY_ENGINE_OPTIONS = { + 'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')), + 'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE')) + } self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO') @@ -194,20 +205,35 @@ class Config: self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://') # hosted provider credentials - self.OPENAI_API_KEY = get_env('OPENAI_API_KEY') - self.ANTHROPIC_API_KEY = get_env('ANTHROPIC_API_KEY') + self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED') + self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY') + self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE') + self.HOSTED_OPENAI_API_ORGANIZATION = get_env('HOSTED_OPENAI_API_ORGANIZATION') + self.HOSTED_OPENAI_QUOTA_LIMIT = get_env('HOSTED_OPENAI_QUOTA_LIMIT') + self.HOSTED_OPENAI_PAID_ENABLED = get_bool_env('HOSTED_OPENAI_PAID_ENABLED') + self.HOSTED_OPENAI_PAID_STRIPE_PRICE_ID = get_env('HOSTED_OPENAI_PAID_STRIPE_PRICE_ID') + self.HOSTED_OPENAI_PAID_INCREASE_QUOTA = int(get_env('HOSTED_OPENAI_PAID_INCREASE_QUOTA')) - self.OPENAI_HOSTED_QUOTA_LIMIT = get_env('OPENAI_HOSTED_QUOTA_LIMIT') - self.ANTHROPIC_HOSTED_QUOTA_LIMIT = get_env('ANTHROPIC_HOSTED_QUOTA_LIMIT') + self.HOSTED_AZURE_OPENAI_ENABLED = get_bool_env('HOSTED_AZURE_OPENAI_ENABLED') + self.HOSTED_AZURE_OPENAI_API_KEY = get_env('HOSTED_AZURE_OPENAI_API_KEY') + self.HOSTED_AZURE_OPENAI_API_BASE = get_env('HOSTED_AZURE_OPENAI_API_BASE') + self.HOSTED_AZURE_OPENAI_QUOTA_LIMIT = get_env('HOSTED_AZURE_OPENAI_QUOTA_LIMIT') + + self.HOSTED_ANTHROPIC_ENABLED = get_bool_env('HOSTED_ANTHROPIC_ENABLED') + self.HOSTED_ANTHROPIC_API_BASE = get_env('HOSTED_ANTHROPIC_API_BASE') + self.HOSTED_ANTHROPIC_API_KEY = get_env('HOSTED_ANTHROPIC_API_KEY') + self.HOSTED_ANTHROPIC_QUOTA_LIMIT = get_env('HOSTED_ANTHROPIC_QUOTA_LIMIT') + self.HOSTED_ANTHROPIC_PAID_ENABLED = get_bool_env('HOSTED_ANTHROPIC_PAID_ENABLED') + self.HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID = get_env('HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID') + self.HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA = get_env('HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA') + + self.STRIPE_API_KEY = get_env('STRIPE_API_KEY') + self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET') # By default it is False # You could disable it for compatibility with certain OpenAPI providers self.DISABLE_PROVIDER_CONFIG_VALIDATION = get_bool_env('DISABLE_PROVIDER_CONFIG_VALIDATION') - # For temp use only - # set default LLM provider, default is 'openai', support `azure_openai` - self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') - # notion import setting self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index deb4c44250..4834f84555 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -18,10 +18,13 @@ from .auth import login, oauth, data_source_oauth, activate from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source # Import workspace controllers -from .workspace import workspace, members, model_providers, account, tool_providers +from .workspace import workspace, members, providers, model_providers, account, 
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index eb443931d2..d0bbbe8bf0 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -2,16 +2,17 @@
 import json
 from datetime import datetime
 
-import flask
 from flask_login import login_required, current_user
 from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
-from werkzeug.exceptions import Unauthorized, Forbidden
+from werkzeug.exceptions import Forbidden
 
 from constants.model_template import model_templates, demo_model_templates
 from controllers.console import api
-from controllers.console.app.error import AppNotFoundError
+from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.model_params import ModelType
 from events.app_event import app_was_created, app_was_deleted
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -126,9 +127,9 @@ class AppListApi(Resource):
         if args['model_config'] is not None:
             # validate config
             model_configuration = AppModelConfigService.validate_configuration(
+                tenant_id=current_user.current_tenant_id,
                 account=current_user,
-                config=args['model_config'],
-                mode=args['mode']
+                config=args['model_config']
             )
 
             app = App(
@@ -164,6 +165,21 @@ class AppListApi(Resource):
             app = App(**model_config_template['app'])
             app_model_config = AppModelConfig(**model_config_template['model_config'])
 
+            default_model = ModelFactory.get_default_model(
+                tenant_id=current_user.current_tenant_id,
+                model_type=ModelType.TEXT_GENERATION
+            )
+
+            if default_model:
+                model_dict = app_model_config.model_dict
+                model_dict['provider'] = default_model.provider_name
+                model_dict['name'] = default_model.model_name
+                app_model_config.model = json.dumps(model_dict)
+            else:
+                raise ProviderNotInitializeError(
+                    f"No Text Generation Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
+
         app.name = args['name']
         app.mode = args['mode']
         app.icon = args['icon']
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 075e8d4a91..16749e994e 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -14,7 +14,7 @@ from controllers.console.app.error import AppUnavailableError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from flask_restful import Resource
 from services.audio_service import AudioService
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index e76186671d..0773abb202 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -17,7 +17,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from flask_restful import Resource, reqparse
@@ -41,8 +41,11 @@ class CompletionMessageApi(Resource):
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()
 
+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user
 
         try:
@@ -51,7 +54,7 @@ class CompletionMessageApi(Resource):
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
 
@@ -111,8 +114,11 @@ class ChatMessageApi(Resource):
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         args = parser.parse_args()
 
+        streaming = args['response_mode'] != 'blocking'
+
         account = flask_login.current_user
 
         try:
@@ -121,7 +127,7 @@ class ChatMessageApi(Resource):
                 user=account,
                 args=args,
                 from_source='console',
-                streaming=True,
+                streaming=streaming,
                 is_model_config_override=True
             )
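Note: with the `response_mode` argument above, console clients can now opt out of SSE streaming per request. A hypothetical client call — the host, path prefix, and auth scheme are assumptions, not defined in this patch:

```python
import requests


def create_completion(base_url: str, app_id: str, token: str, query: str) -> dict:
    """Blocking completion call; returns one JSON payload instead of an SSE stream."""
    resp = requests.post(
        f"{base_url}/console/api/apps/{app_id}/completion-messages",
        headers={"Authorization": f"Bearer {token}"},
        json={
            "inputs": {},
            "query": query,
            "model_config": {},           # model config override for this run
            "response_mode": "blocking",  # omit to keep the streaming default
        },
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()
```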
@@ -7,7 +7,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.generator.llm_generator import LLMGenerator
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, LLMBadRequestError, LLMAPIConnectionError, \
     LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
 
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index c5764a7ec7..9c527eddc9 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -14,7 +14,7 @@ from controllers.console.app.error import CompletionRequestError, ProviderNotIni
     AppMoreLikeThisDisabledError, ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index 3197f9eba1..d0c648ba16 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -28,9 +28,9 @@ class ModelConfigResource(Resource):
 
         # validate config
         model_configuration = AppModelConfigService.validate_configuration(
+            tenant_id=current_user.current_tenant_id,
             account=current_user,
-            config=request.json,
-            mode=app_model.mode
+            config=request.json
         )
 
         new_app_model_config = AppModelConfig(
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 4151132621..65f8225f12 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -255,7 +255,7 @@ class DataSourceNotionApi(Resource):
         # validate args
         DocumentService.estimate_args_validate(args)
         indexing_runner = IndexingRunner()
-        response = indexing_runner.notion_indexing_estimate(args['notion_info_list'], args['process_rule'])
+        response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id, args['notion_info_list'], args['process_rule'])
         return response, 200
 
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 1881103447..dfe9026eb4 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -5,10 +5,13 @@ from flask_restful import Resource, reqparse, fields, marshal, marshal_with
 from werkzeug.exceptions import NotFound, Forbidden
 import services
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetNameDuplicateError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from libs.helper import TimestampField
 from extensions.ext_database import db
 from models.dataset import DocumentSegment, Document
@@ -97,6 +100,15 @@ class DatasetListApi(Resource):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             dataset = DatasetService.create_empty_dataset(
                 tenant_id=current_user.current_tenant_id,
@@ -235,12 +247,26 @@ class DatasetIndexingEstimateApi(Resource):
                 raise NotFound("File not found.")
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif args['info_list']['data_source_type'] == 'notion_import':
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(args['info_list']['notion_info_list'],
-                                                                args['process_rule'], args['doc_form'])
+
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    args['info_list']['notion_info_list'],
+                                                                    args['process_rule'], args['doc_form'])
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response, 200
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 02ddfbf467..a1ef7b767c 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -18,7 +18,9 @@ from controllers.console.datasets.error import DocumentAlreadyFinishedError, Inv
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
 from core.indexing_runner import IndexingRunner
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+    LLMBadRequestError
+from core.model_providers.model_factory import ModelFactory
 from extensions.ext_redis import redis_client
 from libs.helper import TimestampField
 from extensions.ext_database import db
@@ -280,6 +282,15 @@ class DatasetDocumentListApi(Resource):
         # validate args
         DocumentService.document_create_args_validate(args)
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         try:
             documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user)
         except ProviderTokenNotInitError as ex:
@@ -319,6 +330,15 @@ class DatasetInitApi(Resource):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         args = parser.parse_args()
 
+        try:
+            ModelFactory.get_embedding_model(
+                tenant_id=current_user.current_tenant_id
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"No Embedding Model available. Please configure a valid provider "
+                f"in the Settings -> Model Provider.")
+
         # validate args
         DocumentService.document_create_args_validate(args)
 
@@ -384,7 +404,13 @@ class DocumentIndexingEstimateApi(DocumentResource):
 
             indexing_runner = IndexingRunner()
 
-            response = indexing_runner.file_indexing_estimate([file], data_process_rule_dict)
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, [file],
+                                                                  data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
 
         return response
 
@@ -445,12 +471,24 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
                 raise NotFound("File not found.")
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.file_indexing_estimate(file_details, data_process_rule_dict)
+            try:
+                response = indexing_runner.file_indexing_estimate(current_user.current_tenant_id, file_details,
+                                                                  data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         elif dataset.data_source_type:
 
             indexing_runner = IndexingRunner()
-            response = indexing_runner.notion_indexing_estimate(info_list,
-                                                                data_process_rule_dict)
+            try:
+                response = indexing_runner.notion_indexing_estimate(current_user.current_tenant_id,
+                                                                    info_list,
+                                                                    data_process_rule_dict)
+            except LLMBadRequestError:
+                raise ProviderNotInitializeError(
+                    f"No Embedding Model available. Please configure a valid provider "
+                    f"in the Settings -> Model Provider.")
         else:
             raise ValueError('Data source type not support')
         return response
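Note: the try/except guard around `ModelFactory.get_embedding_model` is repeated in several controllers above. A shared helper would express the same pattern once; the helper name here is an editorial sketch, not something the patch defines:

```python
from controllers.console.app.error import ProviderNotInitializeError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory


def ensure_embedding_model(tenant_id: str) -> None:
    """Raise ProviderNotInitializeError if the tenant has no usable embedding model."""
    try:
        ModelFactory.get_embedding_model(tenant_id=tenant_id)
    except LLMBadRequestError:
        raise ProviderNotInitializeError(
            "No Embedding Model available. Please configure a valid provider "
            "in the Settings -> Model Provider.")
```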
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index f627949d33..399bd4c0c9 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -11,7 +11,7 @@ from controllers.console.app.error import ProviderNotInitializeError, ProviderQu
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import TimestampField
 from services.dataset_service import DatasetService
 from services.hit_testing_service import HitTestingService
@@ -102,6 +102,8 @@ class HitTestingApi(Resource):
             raise ProviderQuotaExceededError()
         except ModelCurrentlyNotSupportError:
             raise ProviderModelCurrentlyNotSupportError()
+        except ValueError as e:
+            raise ValueError(str(e))
         except Exception as e:
             logging.exception("Hit testing failed.")
             raise InternalServerError(str(e))
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 991a228dd5..50ddfac98f 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index bc4b88ad15..d48c85a731 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -15,7 +15,7 @@ from controllers.console.app.error import ConversationCompletedError, AppUnavail
 from controllers.console.explore.error import NotCompletionAppError, NotChatAppError
 from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 1232169eab..160ebee122 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -15,7 +15,7 @@ from controllers.console.app.error import AppMoreLikeThisDisabledError, Provider
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index 09b4f987e7..13e356cb26 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 
 from models.model import InstalledApp
 
@@ -35,13 +33,12 @@ class AppParameterApi(InstalledAppResource):
         """Retrieve app parameters."""
         app_model = installed_app.app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(installed_app.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
diff --git a/api/controllers/console/universal_chat/audio.py b/api/controllers/console/universal_chat/audio.py
index 41d5382c7d..38bcc25b29 100644
--- a/api/controllers/console/universal_chat/audio.py
+++ b/api/controllers/console/universal_chat/audio.py
@@ -11,7 +11,7 @@ from controllers.console.app.error import AppUnavailableError, ProviderNotInitia
     NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
diff --git a/api/controllers/console/universal_chat/chat.py b/api/controllers/console/universal_chat/chat.py
index 2a95eb992b..a6aa842042 100644
--- a/api/controllers/console/universal_chat/chat.py
+++ b/api/controllers/console/universal_chat/chat.py
@@ -12,9 +12,8 @@ from controllers.console import api
 from controllers.console.app.error import ConversationCompletedError, AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.constant import llm_constant
 from core.conversation_message_task import PubHandler
-from core.llm.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
+from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError, \
     LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
@@ -27,6 +26,7 @@ class UniversalChatApi(UniversalChatResource):
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
+        parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
         parser.add_argument('tools', type=list, required=True, location='json')
         args = parser.parse_args()
@@ -36,11 +36,7 @@ class UniversalChatApi(UniversalChatResource):
         # update app model config
         args['model_config'] = app_model_config.to_dict()
         args['model_config']['model']['name'] = args['model']
-
-        if not llm_constant.models[args['model']]:
-            raise ValueError("Model not exists.")
-
-        args['model_config']['model']['provider'] = llm_constant.models[args['model']]
+        args['model_config']['model']['provider'] = args['provider']
         args['model_config']['agent_mode']['tools'] = args['tools']
 
         if not args['model_config']['agent_mode']['tools']:
diff --git a/api/controllers/console/universal_chat/message.py b/api/controllers/console/universal_chat/message.py
index cbb4134828..07d8b37fee 100644
--- a/api/controllers/console/universal_chat/message.py
+++ b/api/controllers/console/universal_chat/message.py
@@ -12,7 +12,7 @@ from controllers.console.app.error import ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.errors.conversation import ConversationNotExistsError
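Note: after this change the universal chat request body must name the provider explicitly instead of the server looking it up in the removed `llm_constant` table. A hypothetical request body (the model and tool values are illustrative, not mandated by the patch):

```python
payload = {
    "query": "Summarize yesterday's meeting notes.",
    "conversation_id": None,        # or an existing conversation UUID
    "provider": "anthropic",        # new required field
    "model": "claude-2",            # illustrative model name
    "tools": [],                    # agent tools; the key must be present
}
```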
diff --git a/api/controllers/console/universal_chat/parameter.py b/api/controllers/console/universal_chat/parameter.py
index c8351d0cb5..b492bba501 100644
--- a/api/controllers/console/universal_chat/parameter.py
+++ b/api/controllers/console/universal_chat/parameter.py
@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 
 from controllers.console import api
 from controllers.console.universal_chat.wraps import UniversalChatResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 
 from models.model import App
 
@@ -23,13 +21,12 @@ class UniversalChatParameterApi(UniversalChatResource):
         """Retrieve app parameters."""
         app_model = universal_app
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(universal_app.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
         }
 
diff --git a/api/tests/test_controllers/__init__.py b/api/controllers/console/webhook/__init__.py
similarity index 100%
rename from api/tests/test_controllers/__init__.py
rename to api/controllers/console/webhook/__init__.py
diff --git a/api/controllers/console/webhook/stripe.py b/api/controllers/console/webhook/stripe.py
new file mode 100644
index 0000000000..da906b0dc8
--- /dev/null
+++ b/api/controllers/console/webhook/stripe.py
@@ -0,0 +1,53 @@
+import logging
+
+import stripe
+from flask import request, current_app
+from flask_restful import Resource
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import only_edition_cloud
+from services.provider_checkout_service import ProviderCheckoutService
+
+
+class StripeWebhookApi(Resource):
+    @setup_required
+    @only_edition_cloud
+    def post(self):
+        payload = request.data
+        sig_header = request.headers.get('STRIPE_SIGNATURE')
+        webhook_secret = current_app.config.get('STRIPE_WEBHOOK_SECRET')
+
+        try:
+            event = stripe.Webhook.construct_event(
+                payload, sig_header, webhook_secret
+            )
+        except ValueError as e:
+            # Invalid payload
+            return 'Invalid payload', 400
+        except stripe.error.SignatureVerificationError as e:
+            # Invalid signature
+            return 'Invalid signature', 400
+
+        # Handle the checkout.session.completed event
+        if event['type'] == 'checkout.session.completed':
+            logging.debug(event['data']['object']['id'])
+            logging.debug(event['data']['object']['amount_subtotal'])
+            logging.debug(event['data']['object']['currency'])
+            logging.debug(event['data']['object']['payment_intent'])
+            logging.debug(event['data']['object']['payment_status'])
+            logging.debug(event['data']['object']['metadata'])
+
+            # Fulfill the purchase...
+            provider_checkout_service = ProviderCheckoutService()
+
+            try:
+                provider_checkout_service.fulfill_provider_order(event)
+            except Exception as e:
+                logging.debug(str(e))
+                return 'success', 200
+
+        return 'success', 200
+
+
+api.add_resource(StripeWebhookApi, '/webhook/stripe')
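Note: `stripe.Webhook.construct_event` rejects any request whose `Stripe-Signature` header does not match the raw payload. For reference, Stripe's documented v1 scheme can be checked by hand roughly like this — a simplified sketch; the SDK additionally handles multiple signatures and other edge cases:

```python
import hashlib
import hmac
import time


def verify_stripe_signature(payload: bytes, sig_header: str, secret: str,
                            tolerance: int = 300) -> bool:
    # Header format: "t=<unix timestamp>,v1=<hex hmac>[,v1=...]"
    parts = dict(item.split("=", 1) for item in sig_header.split(","))
    timestamp, signature = parts["t"], parts["v1"]

    # The signed message is "<timestamp>.<raw body>", HMAC-SHA256 with the secret.
    signed_payload = timestamp.encode() + b"." + payload
    expected = hmac.new(secret.encode(), signed_payload, hashlib.sha256).hexdigest()

    # Constant-time comparison, plus a replay window on the timestamp.
    return (hmac.compare_digest(expected, signature)
            and abs(time.time() - int(timestamp)) <= tolerance)
```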
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index 1991b1c6c9..d32591803a 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -1,24 +1,18 @@
-# -*- coding:utf-8 -*-
-import base64
-import json
-import logging
-
-from flask import current_app
 from flask_login import login_required, current_user
-from flask_restful import Resource, reqparse, abort
+from flask_restful import Resource, reqparse
 from werkzeug.exceptions import Forbidden
 
 from controllers.console import api
+from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
-from core.llm.provider.errors import ValidateFailedError
-from extensions.ext_database import db
-from libs import rsa
-from models.provider import Provider, ProviderType, ProviderName
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from services.provider_checkout_service import ProviderCheckoutService
 from services.provider_service import ProviderService
 
 
-class ProviderListApi(Resource):
+class ModelProviderListApi(Resource):
 
     @setup_required
     @login_required
@@ -26,156 +20,36 @@ def get(self):
         tenant_id = current_user.current_tenant_id
 
-        """
-        If the type is AZURE_OPENAI, decode and return the four fields of azure_api_type, azure_api_version:,
-        azure_api_base, azure_api_key as an object, where azure_api_key displays the first 6 bits in plaintext, and the
-        rest is replaced by * and the last two bits are displayed in plaintext
-
-        If the type is other, decode and return the Token field directly, the field displays the first 6 bits in
-        plaintext, the rest is replaced by * and the last two bits are displayed in plaintext
-        """
-
-        ProviderService.init_supported_provider(current_user.current_tenant)
-        providers = Provider.query.filter_by(tenant_id=tenant_id).all()
-
-        provider_list = [
-            {
-                'provider_name': p.provider_name,
-                'provider_type': p.provider_type,
-                'is_valid': p.is_valid,
-                'last_used': p.last_used,
-                'is_enabled': p.is_enabled,
-                **({
-                    'quota_type': p.quota_type,
-                    'quota_limit': p.quota_limit,
-                    'quota_used': p.quota_used
-                } if p.provider_type == ProviderType.SYSTEM.value else {}),
-                'token': ProviderService.get_obfuscated_api_key(current_user.current_tenant,
-                                                                ProviderName(p.provider_name), only_custom=True)
-                if p.provider_type == ProviderType.CUSTOM.value else None
-            }
-            for p in providers
-        ]
+        provider_service = ProviderService()
+        provider_list = provider_service.get_provider_list(tenant_id)
 
         return provider_list
 
 
-class ProviderTokenApi(Resource):
+class ModelProviderValidateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
-        # The role of the current user in the ta table must be admin or owner
-        if current_user.current_tenant.current_role not in ['admin', 'owner']:
-            logging.log(logging.ERROR,
-                        f'User {current_user.id} is not authorized to update provider token, current_role is {current_user.current_tenant.current_role}')
-            raise Forbidden()
+    def post(self, provider_name: str):
 
         parser = reqparse.RequestParser()
-
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
-
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
-        if args['token']:
-            try:
-                ProviderService.validate_provider_configs(
-                    tenant=current_user.current_tenant,
-                    provider_name=ProviderName(provider),
-                    configs=args['token']
-                )
-                token_is_valid = True
-            except ValidateFailedError as ex:
-                raise ValueError(str(ex))
-
-            base64_encrypted_token = ProviderService.get_encrypted_token(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
-            )
-        else:
-            base64_encrypted_token = None
-            token_is_valid = False
-
-        tenant = current_user.current_tenant
-
-        provider_model = db.session.query(Provider).filter(
-            Provider.tenant_id == tenant.id,
-            Provider.provider_name == provider,
-            Provider.provider_type == ProviderType.CUSTOM.value
-        ).first()
-
-        # Only allow updating token for CUSTOM provider type
-        if provider_model:
-            provider_model.encrypted_config = base64_encrypted_token
-            provider_model.is_valid = token_is_valid
-        else:
-            provider_model = Provider(tenant_id=tenant.id, provider_name=provider,
-                                      provider_type=ProviderType.CUSTOM.value,
-                                      encrypted_config=base64_encrypted_token,
-                                      is_valid=token_is_valid)
-            db.session.add(provider_model)
-
-        if provider in [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value] and provider_model.is_valid:
-            other_providers = db.session.query(Provider).filter(
-                Provider.tenant_id == tenant.id,
-                Provider.provider_name.in_([ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value]),
-                Provider.provider_name != provider,
-                Provider.provider_type == ProviderType.CUSTOM.value
-            ).all()
-
-            for other_provider in other_providers:
-                other_provider.is_valid = False
-
-        db.session.commit()
-
-        if provider in [ProviderName.AZURE_OPENAI.value, ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}, 201
-
-        return {'result': 'success'}, 201
-
-
-class ProviderTokenValidateApi(Resource):
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def post(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('token', type=ProviderService.get_token_type(
-            tenant=current_user.current_tenant,
-            provider_name=ProviderName(provider)
-        ), required=True, nullable=False, location='json')
-        args = parser.parse_args()
-
-        # todo: remove this when the provider is supported
-        if provider in [ProviderName.COHERE.value,
-                        ProviderName.HUGGINGFACEHUB.value]:
-            return {'result': 'success', 'warning': 'MOCK: This provider is not supported yet.'}
+        provider_service = ProviderService()
 
         result = True
         error = None
 
         try:
-            ProviderService.validate_provider_configs(
-                tenant=current_user.current_tenant,
-                provider_name=ProviderName(provider),
-                configs=args['token']
+            provider_service.custom_provider_config_validate(
+                provider_name=provider_name,
+                config=args['config']
             )
-        except ValidateFailedError as e:
+        except CredentialsValidateFailedError as ex:
             result = False
-            error = str(e)
+            error = str(ex)
 
         response = {'result': 'success' if result else 'error'}
 
@@ -185,91 +59,227 @@ class ProviderTokenValidateApi(Resource):
         return response
 
 
-class ProviderSystemApi(Resource):
+class ModelProviderUpdateApi(Resource):
 
     @setup_required
     @login_required
     @account_initialization_required
-    def put(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
-        parser = reqparse.RequestParser()
-        parser.add_argument('is_enabled', type=bool, required=True, location='json')
-        args = parser.parse_args()
-
-        tenant = current_user.current_tenant_id
-
-        provider_model = Provider.query.filter_by(tenant_id=tenant.id, provider_name=provider).first()
-
-        if provider_model and provider_model.provider_type == ProviderType.SYSTEM.value:
-            provider_model.is_valid = args['is_enabled']
-            db.session.commit()
-        elif not provider_model:
-            if provider == ProviderName.OPENAI.value:
-                quota_limit = current_app.config['OPENAI_HOSTED_QUOTA_LIMIT']
-            elif provider == ProviderName.ANTHROPIC.value:
-                quota_limit = current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT']
-            else:
-                quota_limit = 0
-
-            ProviderService.create_system_provider(
-                tenant,
-                provider,
-                quota_limit,
-                args['is_enabled']
-            )
-        else:
-            abort(403)
-
-        return {'result': 'success'}
-
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, provider):
-        if provider not in [p.value for p in ProviderName]:
-            abort(404)
-
-        # The role of the current user in the ta table must be admin or owner
+    def post(self, provider_name: str):
         if current_user.current_tenant.current_role not in ['admin', 'owner']:
             raise Forbidden()
 
-        provider_model = db.session.query(Provider).filter(Provider.tenant_id == current_user.current_tenant_id,
-                                                           Provider.provider_name == provider,
-                                                           Provider.provider_type == ProviderType.SYSTEM.value).first()
+        parser = reqparse.RequestParser()
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
 
-        system_model = None
-        if provider_model:
-            system_model = {
-                'result': 'success',
-                'provider': {
-                    'provider_name': provider_model.provider_name,
-                    'provider_type': provider_model.provider_type,
-                    'is_valid': provider_model.is_valid,
-                    'last_used': provider_model.last_used,
-                    'is_enabled': provider_model.is_enabled,
-                    'quota_type': provider_model.quota_type,
-                    'quota_limit': provider_model.quota_limit,
-                    'quota_used': provider_model.quota_used
-                }
+        provider_service = ProviderService()
+
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                config=args['config']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 201
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name
+        )
+
+        return {'result': 'success'}, 204
+
+
+class ModelProviderModelValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_model_config_validate(
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+class ModelProviderModelUpdateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('config', type=dict, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        try:
+            provider_service.add_or_save_custom_provider_model_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type=args['model_type'],
+                config=args['config']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 200
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.delete_custom_provider_model(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            model_name=args['model_name'],
+            model_type=args['model_type']
+        )
+
+        return {'result': 'success'}, 204
+
+
+class PreferredProviderTypeUpdateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider_name: str):
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False,
+                            choices=['system', 'custom'], location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.switch_preferred_provider(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            preferred_provider_type=args['preferred_provider_type']
+        )
+
+        return {'result': 'success'}
+
+
+class ModelProviderModelParameterRuleApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='args')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        try:
+            parameter_rules = provider_service.get_model_parameter_rules(
+                tenant_id=current_user.current_tenant_id,
+                model_provider_name=provider_name,
+                model_name=args['model_name'],
+                model_type='text-generation'
+            )
+        except LLMBadRequestError:
+            raise ProviderNotInitializeError(
+                f"Current Text Generation Model is invalid. Please switch to the available model.")
+
+        rules = {
+            k: {
+                'enabled': v.enabled,
+                'min': v.min,
+                'max': v.max,
+                'default': v.default
             }
-        else:
-            abort(404)
+            for k, v in vars(parameter_rules).items()
+        }
 
-        return system_model
+        return rules
 
 
-api.add_resource(ProviderTokenApi, '/providers/<provider>/token',
-                 endpoint='current_providers_token')  # Deprecated
-api.add_resource(ProviderTokenValidateApi, '/providers/<provider>/token-validate',
-                 endpoint='current_providers_token_validate')  # Deprecated
+class ModelProviderPaymentCheckoutUrlApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, provider_name: str):
+        provider_service = ProviderCheckoutService()
+        provider_checkout = provider_service.create_checkout(
+            tenant_id=current_user.current_tenant_id,
+            provider_name=provider_name,
+            account=current_user
+        )
 
-api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
-                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
-api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
-                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
+        return {
+            'url': provider_checkout.get_checkout_url()
+        }
 
-api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
-api.add_resource(ProviderSystemApi, '/workspaces/current/providers/<provider>/system',
-                 endpoint='workspaces_current_providers_system')  # GET for getting provider quota, PUT for updating provider status
+
+api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers')
+api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers/<provider_name>/validate')
+api.add_resource(ModelProviderUpdateApi, '/workspaces/current/model-providers/<provider_name>')
+api.add_resource(ModelProviderModelValidateApi,
+                 '/workspaces/current/model-providers/<provider_name>/models/validate')
+api.add_resource(ModelProviderModelUpdateApi,
+                 '/workspaces/current/model-providers/<provider_name>/models')
+api.add_resource(PreferredProviderTypeUpdateApi,
+                 '/workspaces/current/model-providers/<provider_name>/preferred-provider-type')
+api.add_resource(ModelProviderModelParameterRuleApi,
+                 '/workspaces/current/model-providers/<provider_name>/models/parameter-rules')
+api.add_resource(ModelProviderPaymentCheckoutUrlApi,
+                 '/workspaces/current/model-providers/<provider_name>/checkout-url')
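Note: the route parameters above (e.g. `<provider_name>`) are restored from the handler signatures. The new endpoints all take JSON bodies validated by reqparse; a hypothetical call against the model validation route — host, path prefix, auth, and the Azure credential keys are assumptions, not defined in this patch:

```python
import requests


def validate_model(base_url: str, token: str, provider_name: str) -> dict:
    resp = requests.post(
        f"{base_url}/console/api/workspaces/current/model-providers/{provider_name}/models/validate",
        headers={"Authorization": f"Bearer {token}"},
        json={
            "model_name": "gpt-35-turbo",
            "model_type": "text-generation",   # or "embeddings" / "speech2text"
            "config": {                        # provider-specific credentials
                "openai_api_base": "https://example-resource.openai.azure.com",
                "openai_api_key": "azure-api-key",
            },
        },
    )
    return resp.json()  # {"result": "success"} or {"result": "error", "error": "..."}
```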
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
new file mode 100644
index 0000000000..33f8edfed7
--- /dev/null
+++ b/api/controllers/console/workspace/models.py
@@ -0,0 +1,108 @@
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.model_providers.models.entity.model_params import ModelType
+from models.provider import ProviderType
+from services.provider_service import ProviderService
+
+
+class DefaultModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+        args = parser.parse_args()
+
+        tenant_id = current_user.current_tenant_id
+
+        provider_service = ProviderService()
+        default_model = provider_service.get_default_model_of_model_type(
+            tenant_id=tenant_id,
+            model_type=args['model_type']
+        )
+
+        if not default_model:
+            return None
+
+        model_provider = ModelProviderFactory.get_preferred_model_provider(
+            tenant_id,
+            default_model.provider_name
+        )
+
+        if not model_provider:
+            return {
+                'model_name': default_model.model_name,
+                'model_type': default_model.model_type,
+                'model_provider': {
+                    'provider_name': default_model.provider_name
+                }
+            }
+
+        provider = model_provider.provider
+        rst = {
+            'model_name': default_model.model_name,
+            'model_type': default_model.model_type,
+            'model_provider': {
+                'provider_name': provider.provider_name,
+                'provider_type': provider.provider_type
+            }
+        }
+
+        model_provider_rules = ModelProviderFactory.get_provider_rule(default_model.provider_name)
+        if provider.provider_type == ProviderType.SYSTEM.value:
+            rst['model_provider']['quota_type'] = provider.quota_type
+            rst['model_provider']['quota_unit'] = model_provider_rules['system_config']['quota_unit']
+            rst['model_provider']['quota_limit'] = provider.quota_limit
+            rst['model_provider']['quota_used'] = provider.quota_used
+
+        return rst
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_type', type=str, required=True, nullable=False,
+                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
+        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+        provider_service.update_default_model_of_model_type(
+            tenant_id=current_user.current_tenant_id,
+            model_type=args['model_type'],
+            provider_name=args['provider_name'],
+            model_name=args['model_name']
+        )
+
+        return {'result': 'success'}
+
+
+class ValidModelApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, model_type):
+        ModelType.value_of(model_type)
+
+        provider_service = ProviderService()
+        valid_models = provider_service.get_valid_model_list(
+            tenant_id=current_user.current_tenant_id,
+            model_type=model_type
+        )
+
+        return valid_models
+
+
+api.add_resource(DefaultModelApi, '/workspaces/current/default-model')
+api.add_resource(ValidModelApi, '/workspaces/current/models/model-type/<string:model_type>')
diff --git a/api/controllers/console/workspace/providers.py b/api/controllers/console/workspace/providers.py
new file mode 100644
index 0000000000..b6f9c3c697
--- /dev/null
+++ b/api/controllers/console/workspace/providers.py
@@ -0,0 +1,130 @@
+# -*- coding:utf-8 -*-
+from flask_login import login_required, current_user
+from flask_restful import Resource, reqparse
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from models.provider import ProviderType
+from services.provider_service import ProviderService
+
+
+class ProviderListApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        tenant_id = current_user.current_tenant_id
+
+        """
+        If the type is AZURE_OPENAI, decode and return the four fields azure_api_type, azure_api_version,
+        azure_api_base and azure_api_key as an object, where azure_api_key shows the first 6 characters in
+        plaintext, the rest replaced by *, and the last two characters in plaintext.
+
+        If the type is other, decode and return the token field directly; the field shows the first 6
+        characters in plaintext, the rest replaced by *, and the last two characters in plaintext.
+        """
+
+        provider_service = ProviderService()
+        provider_info_list = provider_service.get_provider_list(tenant_id)
+
+        provider_list = [
+            {
+                'provider_name': p['provider_name'],
+                'provider_type': p['provider_type'],
+                'is_valid': p['is_valid'],
+                'last_used': p['last_used'],
+                'is_enabled': p['is_valid'],
+                **({
+                    'quota_type': p['quota_type'],
+                    'quota_limit': p['quota_limit'],
+                    'quota_used': p['quota_used']
+                } if p['provider_type'] == ProviderType.SYSTEM.value else {}),
+                'token': (p['config'] if p['provider_name'] != 'openai' else p['config']['openai_api_key'])
+                if p['config'] else None
+            }
+            for name, provider_info in provider_info_list.items()
+            for p in provider_info['providers']
+        ]
+
+        return provider_list
+
+
+class ProviderTokenApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider):
+        # The role of the current user in the ta table must be admin or owner
+        if current_user.current_tenant.current_role not in ['admin', 'owner']:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        if provider == 'openai':
+            args['token'] = {
+                'openai_api_key': args['token']
+            }
+
+        provider_service = ProviderService()
+        try:
+            provider_service.save_custom_provider_config(
+                tenant_id=current_user.current_tenant_id,
+                provider_name=provider,
+                config=args['token']
+            )
+        except CredentialsValidateFailedError as ex:
+            raise ValueError(str(ex))
+
+        return {'result': 'success'}, 201
+
+
+class ProviderTokenValidateApi(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def post(self, provider):
+        parser = reqparse.RequestParser()
+        parser.add_argument('token', required=True, nullable=False, location='json')
+        args = parser.parse_args()
+
+        provider_service = ProviderService()
+
+        if provider == 'openai':
+            args['token'] = {
+                'openai_api_key': args['token']
+            }
+
+        result = True
+        error = None
+
+        try:
+            provider_service.custom_provider_config_validate(
+                provider_name=provider,
+                config=args['token']
+            )
+        except CredentialsValidateFailedError as ex:
+            result = False
+            error = str(ex)
+
+        response = {'result': 'success' if result else 'error'}
+
+        if not result:
+            response['error'] = error
+
+        return response
+
+
+api.add_resource(ProviderTokenApi, '/workspaces/current/providers/<provider>/token',
+                 endpoint='workspaces_current_providers_token')  # PUT for updating provider token
+api.add_resource(ProviderTokenValidateApi, '/workspaces/current/providers/<provider>/token-validate',
+                 endpoint='workspaces_current_providers_token_validate')  # POST for validating provider token
+
+api.add_resource(ProviderListApi, '/workspaces/current/providers')  # GET for getting providers list
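Review note: the docstring above describes the credential masking rule (first six characters and last two in plaintext, the middle replaced by `*`). A minimal sketch of that rule, using a hypothetical helper name — the real service may implement it differently:

```python
def obfuscate_token(token: str) -> str:
    """Mask a credential: keep the first 6 and last 2 characters visible.

    Hypothetical helper illustrating the masking rule described above.
    """
    if not token or len(token) <= 8:
        return token
    return token[:6] + '*' * (len(token) - 8) + token[-2:]

# e.g. obfuscate_token('sk-abcdef1234567890xy') -> 'sk-abc' + '*' * 13 + 'xy'
```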
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 2ad457c79b..8b0237eb25 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -30,7 +30,7 @@ tenant_fields = {
     'created_at': TimestampField,
     'role': fields.String,
     'providers': fields.List(fields.Nested(provider_fields)),
-    'in_trail': fields.Boolean,
+    'in_trial': fields.Boolean,
     'trial_end_reason': fields.String,
 }
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 7c185ec633..481133367e 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -4,8 +4,6 @@ from flask_restful import fields, marshal_with
 
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
@@ -35,13 +33,12 @@ class AppParameterApi(AppApiResource):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 470afc6b42..4b03de0637 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -9,7 +9,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, \
     ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, UnsupportedAudioTypeError, \
     ProviderNotSupportSpeechToTextError
 from controllers.service_api.wraps import AppApiResource
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from models.model import App, AppModelConfig
 from services.audio_service import AudioService
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 448c408bce..2b802dc71c 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -14,7 +14,7 @@ from controllers.service_api.app.error import AppUnavailableError, ProviderNotInitializeError, \
     ProviderModelCurrentlyNotSupportError
 from controllers.service_api.wraps import AppApiResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMAPIUnavailableError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index e00de0f9a1..7cb4d49897 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -11,7 +11,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
 from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
     DatasetNotInitedError
 from controllers.service_api.wraps import DatasetApiResource
-from core.llm.error import ProviderTokenNotInitError
+from core.model_providers.error import ProviderTokenNotInitError
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 from models.model import UploadFile
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index f4e268941e..dd66707345 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -4,8 +4,6 @@ from flask_restful import marshal_with, fields
 
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
-from core.llm.llm_builder import LLMBuilder
-from models.provider import ProviderName
 from models.model import App
 
@@ -34,13 +32,12 @@ class AppParameterApi(WebApiResource):
     def get(self, app_model: App, end_user):
         """Retrieve app parameters."""
         app_model_config = app_model.app_model_config
-        provider_name = LLMBuilder.get_default_provider(app_model.tenant_id, 'whisper-1')
 
         return {
             'opening_statement': app_model_config.opening_statement,
             'suggested_questions': app_model_config.suggested_questions_list,
             'suggested_questions_after_answer': app_model_config.suggested_questions_after_answer_dict,
-            'speech_to_text': app_model_config.speech_to_text_dict if provider_name == ProviderName.OPENAI.value else { 'enabled': False },
+            'speech_to_text': app_model_config.speech_to_text_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list
         }
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index 3e3fe3a28d..b3272de1c7 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -10,7 +10,7 @@ from controllers.web.error import AppUnavailableError, ProviderNotInitializeError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError, NoAudioUploadedError, AudioTooLargeError, \
     UnsupportedAudioTypeError, ProviderNotSupportSpeechToTextError
 from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from services.audio_service import AudioService
 from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, \
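Review note: the controller changes above are all the same mechanical move — LLM error types now come from `core.model_providers.error` instead of `core.llm.error`, and the controllers keep translating them into HTTP-level errors. A condensed sketch of that translation pattern; the controller error classes are stand-ins here (the real ones live in the `controllers.*.error` modules):

```python
# Sketch: core-level errors caught at the controller boundary and re-raised
# as HTTP-facing errors. Illustrative composition, not code from this PR.
from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededError

class ProviderNotInitializeError(Exception): ...   # stand-in controller error
class ProviderQuotaExceededError(Exception): ...   # stand-in controller error

def call_service(service_fn):
    try:
        return service_fn()
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(str(ex))
    except QuotaExceededError as ex:
        raise ProviderQuotaExceededError(str(ex))
```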
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index db2f770e5a..4325362a5b 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -14,7 +14,7 @@ from controllers.web.error import AppUnavailableError, ConversationCompletedError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
 from core.conversation_message_task import PubHandler
-from core.llm.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 3d978a1099..f25f1e5af9 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -14,7 +14,7 @@ from controllers.web.error import NotChatAppError, CompletionRequestError, ProviderNotInitializeError, \
     AppMoreLikeThisDisabledError, NotCompletionAppError, AppSuggestedQuestionsAfterAnswerDisabledError, \
     ProviderQuotaExceededError, ProviderModelCurrentlyNotSupportError
 from controllers.web.wraps import WebApiResource
-from core.llm.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
+from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
     ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
 from libs.helper import uuid_value, TimestampField
 from services.completion_service import CompletionService
diff --git a/api/core/__init__.py b/api/core/__init__.py
index 2dc9a9e869..e69de29bb2 100644
--- a/api/core/__init__.py
+++ b/api/core/__init__.py
@@ -1,36 +0,0 @@
-import os
-from typing import Optional
-
-import langchain
-from flask import Flask
-from pydantic import BaseModel
-
-from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.prompt.prompt_template import OneLineFormatter
-
-
-class HostedOpenAICredential(BaseModel):
-    api_key: str
-
-
-class HostedAnthropicCredential(BaseModel):
-    api_key: str
-
-
-class HostedLLMCredentials(BaseModel):
-    openai: Optional[HostedOpenAICredential] = None
-    anthropic: Optional[HostedAnthropicCredential] = None
-
-
-hosted_llm_credentials = HostedLLMCredentials()
-
-
-def init_app(app: Flask):
-    if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
-        langchain.verbose = True
-
-    if app.config.get("OPENAI_API_KEY"):
-        hosted_llm_credentials.openai = HostedOpenAICredential(api_key=app.config.get("OPENAI_API_KEY"))
-
-    if app.config.get("ANTHROPIC_API_KEY"):
-        hosted_llm_credentials.anthropic = HostedAnthropicCredential(api_key=app.config.get("ANTHROPIC_API_KEY"))
diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py
index a07b9f2ad7..97d2b7740f 100644
--- a/api/core/agent/agent/calc_token_mixin.py
+++ b/api/core/agent/agent/calc_token_mixin.py
@@ -1,20 +1,17 @@
-from typing import cast, List
+from typing import List
 
-from langchain import OpenAI
-from langchain.base_language import BaseLanguageModel
-from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema import BaseMessage
 
-from core.constant import llm_constant
+from core.model_providers.models.entity.message import to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class CalcTokenMixin:
 
-    def get_num_tokens_from_messages(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
-        llm = cast(ChatOpenAI, llm)
-        return llm.get_num_tokens_from_messages(messages)
+    def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
+        return model_instance.get_num_tokens(to_prompt_messages(messages))
 
-    def get_message_rest_tokens(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
         """
         Get the rest tokens available for the model after excluding messages tokens and completion max tokens
@@ -22,10 +19,9 @@ class CalcTokenMixin:
 
         :param messages:
         :return:
         """
-        llm = cast(ChatOpenAI, llm)
-        llm_max_tokens = llm_constant.max_context_token_length[llm.model_name]
-        completion_max_tokens = llm.max_tokens
-        used_tokens = self.get_num_tokens_from_messages(llm, messages, **kwargs)
+        llm_max_tokens = model_instance.model_rules.max_tokens.max
+        completion_max_tokens = model_instance.model_kwargs.max_tokens
+        used_tokens = self.get_num_tokens_from_messages(model_instance, messages, **kwargs)
         rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
 
         return rest_tokens
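Review note: the budget arithmetic in `get_message_rest_tokens` is worth a concrete pass — the remaining budget is the model's context window minus the reserved completion tokens minus what the prompt already uses. A small worked example with illustrative numbers:

```python
# Worked example of the rest-tokens budget (numbers are illustrative).
llm_max_tokens = 16384        # model_rules.max_tokens.max, e.g. a 16k context window
completion_max_tokens = 512   # model_kwargs.max_tokens reserved for the answer
used_tokens = 1200            # tokens already consumed by the prompt messages

rest_tokens = llm_max_tokens - completion_max_tokens - used_tokens
assert rest_tokens == 14672   # what is left for context/history before summarizing
```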
""" + model_instance: BaseLLM + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True def should_use_agent(self, query: str): """ diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index 090d35d975..3966525e24 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -6,7 +6,8 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \ from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel +from langchain.schema import AgentAction, AgentFinish, SystemMessage +from langchain.schema.language_model import BaseLanguageModel from langchain.tools import BaseTool from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError @@ -84,7 +85,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio # summarize messages if rest_tokens < 0 try: - messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions) + messages = self.summarize_messages_if_needed(messages, functions=self.functions) except ExceededLLMTokensLimitError as e: return AgentFinish(return_values={"output": str(e)}, log=str(e)) diff --git a/api/core/agent/agent/openai_function_call_summarize_mixin.py b/api/core/agent/agent/openai_function_call_summarize_mixin.py index 0436de2078..a4745e772d 100644 --- a/api/core/agent/agent/openai_function_call_summarize_mixin.py +++ b/api/core/agent/agent/openai_function_call_summarize_mixin.py @@ -3,20 +3,28 @@ from typing import cast, List from langchain.chat_models import ChatOpenAI from langchain.chat_models.openai import _convert_message_to_dict from langchain.memory.summary import SummarizerMixin -from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage, BaseLanguageModel +from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage +from langchain.schema.language_model import BaseLanguageModel from pydantic import BaseModel from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin +from core.model_providers.models.llm.base import BaseLLM class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin): moving_summary_buffer: str = "" moving_summary_index: int = 0 summary_llm: BaseLanguageModel + model_instance: BaseLLM - def summarize_messages_if_needed(self, llm: BaseLanguageModel, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]: # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 - rest_tokens = self.get_message_rest_tokens(llm, messages, **kwargs) + rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs) rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens if rest_tokens >= 0: return messages diff --git a/api/core/agent/agent/openai_multi_function_call.py b/api/core/agent/agent/openai_multi_function_call.py index 1524fc6975..9780377181 100644 --- a/api/core/agent/agent/openai_multi_function_call.py +++ b/api/core/agent/agent/openai_multi_function_call.py @@ -6,7 +6,8 @@ from 
diff --git a/api/core/agent/agent/openai_multi_function_call.py b/api/core/agent/agent/openai_multi_function_call.py
index 1524fc6975..9780377181 100644
--- a/api/core/agent/agent/openai_multi_function_call.py
+++ b/api/core/agent/agent/openai_multi_function_call.py
@@ -6,7 +6,8 @@ from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent
 from langchain.callbacks.base import BaseCallbackManager
 from langchain.callbacks.manager import Callbacks
 from langchain.prompts.chat import BaseMessagePromptTemplate
-from langchain.schema import AgentAction, AgentFinish, SystemMessage, BaseLanguageModel
+from langchain.schema import AgentAction, AgentFinish, SystemMessage
+from langchain.schema.language_model import BaseLanguageModel
 from langchain.tools import BaseTool
 
 from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
@@ -84,7 +85,7 @@ class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
 
         # summarize messages if rest_tokens < 0
         try:
-            messages = self.summarize_messages_if_needed(self.llm, messages, functions=self.functions)
+            messages = self.summarize_messages_if_needed(messages, functions=self.functions)
         except ExceededLLMTokensLimitError as e:
             return AgentFinish(return_values={"output": str(e)}, log=str(e))
diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py
new file mode 100644
index 0000000000..ac1748611d
--- /dev/null
+++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py
@@ -0,0 +1,162 @@
+import re
+from typing import List, Tuple, Any, Union, Sequence, Optional, cast
+
+from langchain import BasePromptTemplate
+from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
+from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
+from langchain.base_language import BaseLanguageModel
+from langchain.callbacks.base import BaseCallbackManager
+from langchain.callbacks.manager import Callbacks
+from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
+from langchain.schema import AgentAction, AgentFinish, OutputParserException
+from langchain.tools import BaseTool
+from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
+
+from core.model_providers.models.llm.base import BaseLLM
+from core.tool.dataset_retriever_tool import DatasetRetrieverTool
+
+FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
+The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
+Valid "action" values: "Final Answer" or {tool_names}
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+```
+{{{{
+  "action": $TOOL_NAME,
+  "action_input": $INPUT
+}}}}
+```
+
+Follow this format:
+
+Question: input question to answer
+Thought: consider previous and subsequent steps
+Action:
+```
+$JSON_BLOB
+```
+Observation: action result
+... (repeat Thought/Action/Observation N times)
+Thought: I know what to respond
+Action:
+```
+{{{{
+  "action": "Final Answer",
+  "action_input": "Final response to human"
+}}}}
+```"""
+
+
+class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
+    model_instance: BaseLLM
+    dataset_tools: Sequence[BaseTool]
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
+
+    def should_use_agent(self, query: str):
+        """
+        return should use agent
+        Using the ReACT mode to determine whether an agent is needed is costly,
+        so it's better to just use an Agent for reasoning, which is cheaper.
+
+        :param query:
+        :return:
+        """
+        return True
+
+    def plan(
+            self,
+            intermediate_steps: List[Tuple[AgentAction, str]],
+            callbacks: Callbacks = None,
+            **kwargs: Any,
+    ) -> Union[AgentAction, AgentFinish]:
+        """Given input, decide what to do.
+
+        Args:
+            intermediate_steps: Steps the LLM has taken to date,
+                along with observations
+            callbacks: Callbacks to run.
+            **kwargs: User inputs.
+
+        Returns:
+            Action specifying what tool to use.
+        """
+        if len(self.dataset_tools) == 0:
+            return AgentFinish(return_values={"output": ''}, log='')
+        elif len(self.dataset_tools) == 1:
+            tool = next(iter(self.dataset_tools))
+            tool = cast(DatasetRetrieverTool, tool)
+            rst = tool.run(tool_input={'dataset_id': tool.dataset_id, 'query': kwargs['input']})
+            return AgentFinish(return_values={"output": rst}, log=rst)
+
+        full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
+        full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
+
+        try:
+            return self.output_parser.parse(full_output)
+        except OutputParserException:
+            return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
+                                          "I don't know how to respond to that."}, "")
+
+    @classmethod
+    def create_prompt(
+            cls,
+            tools: Sequence[BaseTool],
+            prefix: str = PREFIX,
+            suffix: str = SUFFIX,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+    ) -> BasePromptTemplate:
+        tool_strings = []
+        for tool in tools:
+            args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
+            tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
+        formatted_tools = "\n".join(tool_strings)
+        unique_tool_names = set(tool.name for tool in tools)
+        tool_names = ", ".join('"' + name + '"' for name in unique_tool_names)
+        format_instructions = format_instructions.format(tool_names=tool_names)
+        template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
+        if input_variables is None:
+            input_variables = ["input", "agent_scratchpad"]
+        _memory_prompts = memory_prompts or []
+        messages = [
+            SystemMessagePromptTemplate.from_template(template),
+            *_memory_prompts,
+            HumanMessagePromptTemplate.from_template(human_message_template),
+        ]
+        return ChatPromptTemplate(input_variables=input_variables, messages=messages)
+
+    @classmethod
+    def from_llm_and_tools(
+            cls,
+            llm: BaseLanguageModel,
+            tools: Sequence[BaseTool],
+            callback_manager: Optional[BaseCallbackManager] = None,
+            output_parser: Optional[AgentOutputParser] = None,
+            prefix: str = PREFIX,
+            suffix: str = SUFFIX,
+            human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
+            format_instructions: str = FORMAT_INSTRUCTIONS,
+            input_variables: Optional[List[str]] = None,
+            memory_prompts: Optional[List[BasePromptTemplate]] = None,
+            **kwargs: Any,
+    ) -> Agent:
+        return super().from_llm_and_tools(
+            llm=llm,
+            tools=tools,
+            callback_manager=callback_manager,
+            output_parser=output_parser,
+            prefix=prefix,
+            suffix=suffix,
+            human_message_template=human_message_template,
+            format_instructions=format_instructions,
+            input_variables=input_variables,
+            memory_prompts=memory_prompts,
+            dataset_tools=tools,
+            **kwargs,
+        )
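Review note: one subtle detail in `create_prompt` above — each tool's args schema is stringified and every brace is quadrupled, presumably so literal braces survive more than one format-style pass over the assembled template. A quick check of what the escaping mechanically produces:

```python
import re

args = str({'query': {'type': 'string'}})
escaped = re.sub("}", "}}}}", re.sub("{", "{{{{", args))
print(escaped)
# {{{{'query': {{{{'type': 'string'}}}}}}}}
# Each '{' becomes '{{{{' and each '}' becomes '}}}}'; every format-style
# render halves the run, so two passes reduce it back to a single brace.
```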
diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py
index 8c3472845b..96960cf802 100644
--- a/api/core/agent/agent/structured_chat.py
+++ b/api/core/agent/agent/structured_chat.py
@@ -14,7 +14,7 @@
 from langchain.tools import BaseTool
 from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
 
 from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
-
+from core.model_providers.models.llm.base import BaseLLM
 
 FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
 The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
@@ -53,6 +53,12 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
     moving_summary_buffer: str = ""
     moving_summary_index: int = 0
     summary_llm: BaseLanguageModel
+    model_instance: BaseLLM
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        arbitrary_types_allowed = True
 
     def should_use_agent(self, query: str):
         """
@@ -89,7 +95,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
         if prompts:
             messages = prompts[0].to_messages()
 
-        rest_tokens = self.get_message_rest_tokens(self.llm_chain.llm, messages)
+        rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
         if rest_tokens < 0:
             full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py
index da36533fd2..f345e631d3 100644
--- a/api/core/agent/agent_executor.py
+++ b/api/core/agent/agent_executor.py
@@ -3,7 +3,6 @@ import logging
 from typing import Union, Optional
 
 from langchain.agents import BaseSingleActionAgent, BaseMultiActionAgent
-from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.manager import Callbacks
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.tools import BaseTool
@@ -13,14 +12,17 @@ from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
 from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
 from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
 from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
+from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
 from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
 from langchain.agents import AgentExecutor as LCAgentExecutor
 
+from core.model_providers.models.llm.base import BaseLLM
 from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
 class PlanningStrategy(str, enum.Enum):
     ROUTER = 'router'
+    REACT_ROUTER = 'react_router'
     REACT = 'react'
     FUNCTION_CALL = 'function_call'
     MULTI_FUNCTION_CALL = 'multi_function_call'
@@ -28,10 +30,9 @@ class PlanningStrategy(str, enum.Enum):
 
 class AgentConfiguration(BaseModel):
     strategy: PlanningStrategy
-    llm: BaseLanguageModel
+    model_instance: BaseLLM
     tools: list[BaseTool]
-    summary_llm: BaseLanguageModel
-    dataset_llm: BaseLanguageModel
+    summary_model_instance: BaseLLM
     memory: Optional[BaseChatMemory] = None
     callbacks: Callbacks = None
     max_iterations: int = 6
@@ -60,36 +61,49 @@ class AgentExecutor:
     def _init_agent(self) -> Union[BaseSingleActionAgent | BaseMultiActionAgent]:
         if self.configuration.strategy == PlanningStrategy.REACT:
             agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 output_parser=StructuredChatOutputParser(),
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
            )
         elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
             agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
             agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
-                llm=self.configuration.llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,  # used for read chat histories memory
-                summary_llm=self.configuration.summary_llm,
+                summary_llm=self.configuration.summary_model_instance.client,
                 verbose=True
             )
         elif self.configuration.strategy == PlanningStrategy.ROUTER:
             self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
             agent = MultiDatasetRouterAgent.from_llm_and_tools(
-                llm=self.configuration.dataset_llm,
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
                 tools=self.configuration.tools,
                 extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
                 verbose=True
             )
+        elif self.configuration.strategy == PlanningStrategy.REACT_ROUTER:
+            self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
+            agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
+                model_instance=self.configuration.model_instance,
+                llm=self.configuration.model_instance.client,
+                tools=self.configuration.tools,
+                output_parser=StructuredChatOutputParser(),
+                verbose=True
+            )
         else:
             raise NotImplementedError(f"Unknown Agent Strategy: {self.configuration.strategy}")
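Review note: with `AgentConfiguration` now carrying model instances instead of raw LangChain LLMs, wiring up an executor looks roughly like the sketch below. It assumes `ModelFactory.get_text_generation_model(...)` from this PR returns a `BaseLLM`; `tenant_id` and `dataset_tools` are placeholders supplied by the caller:

```python
# Sketch of constructing an agent executor under the new configuration shape.
model_instance = ModelFactory.get_text_generation_model(
    tenant_id=tenant_id,
    model_provider_name='openai',
    model_name='gpt-3.5-turbo'
)

config = AgentConfiguration(
    strategy=PlanningStrategy.REACT_ROUTER,
    model_instance=model_instance,
    summary_model_instance=model_instance,  # could be a cheaper model in practice
    tools=dataset_tools,
    max_iterations=6
)
agent_executor = AgentExecutor(config)
```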
diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py
index bb81771c4c..64fb1bf108 100644
--- a/api/core/callback_handler/agent_loop_gather_callback_handler.py
+++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py
@@ -10,15 +10,16 @@ from langchain.schema import AgentAction, AgentFinish, LLMResult, ChatGeneration
 
 from core.callback_handler.entity.agent_loop import AgentLoop
 from core.conversation_message_task import ConversationMessageTask
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
     """Callback Handler that prints to std out."""
     raise_error: bool = True
 
-    def __init__(self, model_name, conversation_message_task: ConversationMessageTask) -> None:
+    def __init__(self, model_instance: BaseLLM, conversation_message_task: ConversationMessageTask) -> None:
         """Initialize callback handler."""
-        self.model_name = model_name
+        self.model_instance = model_instance
        self.conversation_message_task = conversation_message_task
         self._agent_loops = []
         self._current_loop = None
@@ -152,7 +153,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
@@ -183,7 +184,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
             )
 
             self.conversation_message_task.on_agent_end(
-                self._message_agent_thought, self.model_name, self._current_loop
+                self._message_agent_thought, self.model_instance, self._current_loop
             )
 
             self._agent_loops.append(self._current_loop)
diff --git a/api/core/callback_handler/llm_callback_handler.py b/api/core/callback_handler/llm_callback_handler.py
index 03f8ba2625..89b498c3e8 100644
--- a/api/core/callback_handler/llm_callback_handler.py
+++ b/api/core/callback_handler/llm_callback_handler.py
@@ -3,18 +3,20 @@ import time
 from typing import Any, Dict, List, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.schema import LLMResult, BaseMessage, BaseLanguageModel
+from langchain.schema import LLMResult, BaseMessage
 
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.llm.base import BaseLLM
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
     raise_error: bool = True
 
-    def __init__(self, llm: BaseLanguageModel,
+    def __init__(self, model_instance: BaseLLM,
                  conversation_message_task: ConversationMessageTask):
-        self.llm = llm
+        self.model_instance = model_instance
         self.llm_message = LLMMessage()
         self.start_at = None
         self.conversation_message_task = conversation_message_task
@@ -46,7 +48,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             })
 
         self.llm_message.prompt = real_prompts
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens_from_messages(messages[0])
+        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
 
     def on_llm_start(
             self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -58,7 +60,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
             "text": prompts[0]
         }]
 
-        self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])
+        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         end_at = time.perf_counter()
@@ -68,7 +70,7 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.conversation_message_task.append_message_text(response.generations[0][0].text)
 
         self.llm_message.completion = response.generations[0][0].text
-        self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+        self.llm_message.completion_tokens = self.model_instance.get_num_tokens([PromptMessage(content=self.llm_message.completion)])
 
         self.conversation_message_task.save_message(self.llm_message)
 
@@ -89,7 +91,9 @@ class LLMCallbackHandler(BaseCallbackHandler):
         if self.conversation_message_task.streaming:
             end_at = time.perf_counter()
             self.llm_message.latency = end_at - self.start_at
-            self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)]
+            )
             self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
         else:
             logging.error(error)
diff --git a/api/core/callback_handler/main_chain_gather_callback_handler.py b/api/core/callback_handler/main_chain_gather_callback_handler.py
index e03ecd79f5..fc0a65e42a 100644
--- a/api/core/callback_handler/main_chain_gather_callback_handler.py
+++ b/api/core/callback_handler/main_chain_gather_callback_handler.py
@@ -5,9 +5,7 @@ from typing import Any, Dict, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 
-from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.callback_handler.entity.chain_result import ChainResult
-from core.constant import llm_constant
 from core.conversation_message_task import ConversationMessageTask
diff --git a/api/core/completion.py b/api/core/completion.py
index 486a3b7257..28d0cec8d5 100644
--- a/api/core/completion.py
+++ b/api/core/completion.py
@@ -2,27 +2,19 @@ import logging
 import re
 from typing import Optional, List, Union, Tuple
 
-from langchain.base_language import BaseLanguageModel
-from langchain.callbacks.base import BaseCallbackHandler
-from langchain.chat_models.base import BaseChatModel
-from langchain.llms import BaseLLM
-from langchain.schema import BaseMessage, HumanMessage
+from langchain.schema import BaseMessage
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
-from core.constant import llm_constant
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
-    DifyStdOutCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
-from core.llm.error import LLMBadRequestError
-from core.llm.fake import FakeLLM
-from core.llm.llm_builder import LLMBuilder
-from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
-from core.llm.streamable_open_ai import StreamableOpenAI
+from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
+from core.model_providers.model_factory import ModelFactory
+from core.model_providers.models.entity.message import PromptMessage, to_prompt_messages
+from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import JinjaPromptTemplate
@@ -51,12 +43,10 @@ class Completion:
 
         inputs = conversation.inputs
 
-        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
-            mode=app.mode,
+        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=app.tenant_id,
-            app_model_config=app_model_config,
-            query=query,
-            inputs=inputs
+            model_config=app_model_config.model_dict,
+            streaming=streaming
         )
 
         conversation_message_task = ConversationMessageTask(
@@ -68,10 +58,17 @@ class Completion:
             is_override=is_override,
             inputs=inputs,
             query=query,
-            streaming=streaming
+            streaming=streaming,
+            model_instance=final_model_instance
         )
 
-        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
+        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
+            mode=app.mode,
+            model_instance=final_model_instance,
+            app_model_config=app_model_config,
+            query=query,
+            inputs=inputs
+        )
 
         # init orchestrator rule parser
         orchestrator_rule_parser = OrchestratorRuleParser(
@@ -80,6 +77,7 @@ class Completion:
         )
 
         # parse sensitive_word_avoidance_chain
+        chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
         sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain([chain_callback])
         if sensitive_word_avoidance_chain:
             query = sensitive_word_avoidance_chain.run(query)
@@ -102,15 +100,14 @@ class Completion:
         # run the final llm
         try:
             cls.run_final_llm(
-                tenant_id=app.tenant_id,
+                model_instance=final_model_instance,
                 mode=app.mode,
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
-                memory=memory,
-                streaming=streaming
+                memory=memory
             )
         except ConversationTaskStoppedException:
             return
@@ -121,31 +118,20 @@ class Completion:
             return
 
     @classmethod
-    def run_final_llm(cls, tenant_id: str, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
+    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str, inputs: dict,
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
-                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory], streaming: bool):
+                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]):
         # When no extra pre prompt is specified,
         # the output of the agent can be used directly as the main output content without calling LLM again
+        fake_response = None
         if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                 and agent_execute_result.strategy != PlanningStrategy.ROUTER:
-            final_llm = FakeLLM(response=agent_execute_result.output,
-                                origin_llm=agent_execute_result.configuration.llm,
-                                streaming=streaming)
-            final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-            response = final_llm.generate([[HumanMessage(content=query)]])
-            return response
-
-        final_llm = LLMBuilder.to_llm_from_model(
-            tenant_id=tenant_id,
-            model=app_model_config.model_dict,
-            streaming=streaming
-        )
+            fake_response = agent_execute_result.output
 
         # get llm prompt
-        prompt, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = cls.get_main_llm_prompt(
             mode=mode,
-            llm=final_llm,
             model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
@@ -154,25 +140,26 @@ class Completion:
             memory=memory
         )
 
-        final_llm.callbacks = cls.get_llm_callbacks(final_llm, streaming, conversation_message_task)
-
         cls.recale_llm_max_tokens(
-            final_llm=final_llm,
-            model=app_model_config.model_dict,
-            prompt=prompt,
-            mode=mode
+            model_instance=model_instance,
+            prompt_messages=prompt_messages,
         )
 
-        response = final_llm.generate([prompt], stop_words)
+        response = model_instance.run(
+            messages=prompt_messages,
+            stop=stop_words,
+            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
+            fake_response=fake_response
+        )
 
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, model: dict,
+    def get_main_llm_prompt(cls, mode: str, model: dict,
                             pre_prompt: str, query: str, inputs: dict,
                             agent_execute_result: Optional[AgentExecuteResult],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
+            Tuple[List[PromptMessage], Optional[List[str]]]:
         if mode == 'completion':
             prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
@@ -200,11 +187,7 @@ And answer according to the language of the user's question.
                 **prompt_inputs
             )
 
-            if isinstance(llm, BaseChatModel):
-                # use chat llm as completion model
-                return [HumanMessage(content=prompt_content)], None
-            else:
-                return prompt_content, None
+            return [PromptMessage(content=prompt_content)], None
         else:
             messages: List[BaseMessage] = []
 
@@ -249,12 +232,14 @@ And answer according to the language of the user's question.
                     inputs=human_inputs
                 )
 
-                curr_message_tokens = memory.llm.get_num_tokens_from_messages([tmp_human_message])
-                model_name = model['name']
-                max_tokens = model.get("completion_params").get('max_tokens')
-                rest_tokens = llm_constant.max_context_token_length[model_name] \
-                              - max_tokens - curr_message_tokens
-                rest_tokens = max(rest_tokens, 0)
+                if memory.model_instance.model_rules.max_tokens.max:
+                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                    max_tokens = model.get("completion_params").get('max_tokens')
+                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                    rest_tokens = max(rest_tokens, 0)
+                else:
+                    rest_tokens = 2000
+
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)
                 human_message_prompt += "\n\n" if human_message_prompt else ""
                 human_message_prompt += "Here is the chat histories between human and assistant, " \
@@ -274,17 +259,7 @@ And answer according to the language of the user's question.
             for message in messages:
                 message.content = re.sub(r'<\|.*?\|>', '', message.content)
 
-            return messages, ['\nHuman:', '</histories>']
-
-    @classmethod
-    def get_llm_callbacks(cls, llm: BaseLanguageModel,
-                          streaming: bool,
-                          conversation_message_task: ConversationMessageTask) -> List[BaseCallbackHandler]:
-        llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
-        if streaming:
-            return [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
-        else:
-            return [llm_callback_handler, DifyStdOutCallbackHandler()]
+            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
 
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
@@ -300,15 +275,15 @@ And answer according to the language of the user's question.
                    conversation: Conversation,
                    **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
         # only for calc token in memory
-        memory_llm = LLMBuilder.to_llm_from_model(
+        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=tenant_id,
-            model=app_model_config.model_dict
+            model_config=app_model_config.model_dict
        )
 
         # use llm config from conversation
         memory = ReadOnlyConversationTokenDBBufferSharedMemory(
             conversation=conversation,
-            llm=memory_llm,
+            model_instance=memory_model_instance,
             max_token_limit=kwargs.get("max_token_limit", 2048),
             memory_key=kwargs.get("memory_key", "chat_history"),
             return_messages=kwargs.get("return_messages", True),
@@ -320,21 +295,20 @@ And answer according to the language of the user's question.
 
         return memory
 
     @classmethod
-    def get_validate_rest_tokens(cls, mode: str, tenant_id: str, app_model_config: AppModelConfig,
+    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                  query: str, inputs: dict) -> int:
-        llm = LLMBuilder.to_llm_from_model(
-            tenant_id=tenant_id,
-            model=app_model_config.model_dict
-        )
+        model_limited_tokens = model_instance.model_rules.max_tokens.max
+        max_tokens = model_instance.get_model_kwargs().max_tokens
 
-        model_name = app_model_config.model_dict.get("name")
-        model_limited_tokens = llm_constant.max_context_token_length[model_name]
-        max_tokens = app_model_config.model_dict.get("completion_params").get('max_tokens')
+        if model_limited_tokens is None:
+            return -1
+
+        if max_tokens is None:
+            max_tokens = 0
 
         # get prompt without memory and context
-        prompt, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = cls.get_main_llm_prompt(
             mode=mode,
-            llm=llm,
             model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
             query=query,
@@ -343,9 +317,7 @@ And answer according to the language of the user's question.
             memory=None
         )
 
-        prompt_tokens = llm.get_num_tokens(prompt) if isinstance(prompt, str) \
-            else llm.get_num_tokens_from_messages(prompt)
-
+        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
         rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
         if rest_tokens < 0:
             raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
@@ -354,36 +326,40 @@ And answer according to the language of the user's question.
 
         return rest_tokens
 
     @classmethod
-    def recale_llm_max_tokens(cls, final_llm: BaseLanguageModel, model: dict,
-                              prompt: Union[str, List[BaseMessage]], mode: str):
+    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
-        model_name = model.get("name")
-        model_limited_tokens = llm_constant.max_context_token_length[model_name]
-        max_tokens = model.get("completion_params").get('max_tokens')
+        model_limited_tokens = model_instance.model_rules.max_tokens.max
+        max_tokens = model_instance.get_model_kwargs().max_tokens
 
-        if mode == 'completion' and isinstance(final_llm, BaseLLM):
-            prompt_tokens = final_llm.get_num_tokens(prompt)
-        else:
-            prompt_tokens = final_llm.get_num_tokens_from_messages(prompt)
+        if model_limited_tokens is None:
+            return
+
+        if max_tokens is None:
+            max_tokens = 0
+
+        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
 
         if prompt_tokens + max_tokens > model_limited_tokens:
             max_tokens = max(model_limited_tokens - prompt_tokens, 16)
-            final_llm.max_tokens = max_tokens
+
+            # update model instance max tokens
+            model_kwargs = model_instance.get_model_kwargs()
+            model_kwargs.max_tokens = max_tokens
+            model_instance.set_model_kwargs(model_kwargs)
 
     @classmethod
     def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
                                 app_model_config: AppModelConfig, user: Account, streaming: bool):
 
-        llm = LLMBuilder.to_llm_from_model(
+        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
             tenant_id=app.tenant_id,
-            model=app_model_config.model_dict,
+            model_config=app_model_config.model_dict,
             streaming=streaming
         )
 
         # get llm prompt
-        original_prompt, _ = cls.get_main_llm_prompt(
+        old_prompt_messages, _ = cls.get_main_llm_prompt(
             mode="completion",
-            llm=llm,
             model=app_model_config.model_dict,
             pre_prompt=pre_prompt,
             query=message.query,
@@ -395,10 +371,9 @@ And answer according to the language of the user's question.
 
         original_completion = message.answer.strip()
 
         prompt = MORE_LIKE_THIS_GENERATE_PROMPT
-        prompt = prompt.format(prompt=original_prompt, original_completion=original_completion)
+        prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
 
-        if isinstance(llm, BaseChatModel):
-            prompt = [HumanMessage(content=prompt)]
+        prompt_messages = [PromptMessage(content=prompt)]
 
         conversation_message_task = ConversationMessageTask(
             task_id=task_id,
@@ -408,16 +383,16 @@ And answer according to the language of the user's question.
             inputs=message.inputs,
             query=message.query,
             is_override=True if message.override_model_configs else False,
-            streaming=streaming
+            streaming=streaming,
+            model_instance=final_model_instance
         )
 
-        llm.callbacks = cls.get_llm_callbacks(llm, streaming, conversation_message_task)
-
         cls.recale_llm_max_tokens(
-            final_llm=llm,
-            model=app_model_config.model_dict,
-            prompt=prompt,
-            mode='completion'
+            model_instance=final_model_instance,
+            prompt_messages=prompt_messages
        )
 
-        llm.generate([prompt])
+        final_model_instance.run(
+            messages=prompt_messages,
+            callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
+        )
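Review note: `recale_llm_max_tokens` clamps the completion budget so prompt plus completion never exceeds the context window, with a floor of 16 tokens. Numerically, with illustrative values:

```python
# Illustrative numbers for the max_tokens recalculation above.
model_limited_tokens = 4096   # model_rules.max_tokens.max
max_tokens = 1024             # requested completion budget
prompt_tokens = 3500          # measured via model_instance.get_num_tokens(...)

if prompt_tokens + max_tokens > model_limited_tokens:       # 4524 > 4096
    max_tokens = max(model_limited_tokens - prompt_tokens, 16)
assert max_tokens == 596      # completion budget shrunk to fit the window
```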
diff --git a/api/core/constant/llm_constant.py b/api/core/constant/llm_constant.py
deleted file mode 100644
index 3a02abc90e..0000000000
--- a/api/core/constant/llm_constant.py
+++ /dev/null
@@ -1,109 +0,0 @@
-from _decimal import Decimal
-
-models = {
-    'claude-instant-1': 'anthropic',  # 100,000 tokens
-    'claude-2': 'anthropic',  # 100,000 tokens
-    'gpt-4': 'openai',  # 8,192 tokens
-    'gpt-4-32k': 'openai',  # 32,768 tokens
-    'gpt-3.5-turbo': 'openai',  # 4,096 tokens
-    'gpt-3.5-turbo-16k': 'openai',  # 16384 tokens
-    'text-davinci-003': 'openai',  # 4,097 tokens
-    'text-davinci-002': 'openai',  # 4,097 tokens
-    'text-curie-001': 'openai',  # 2,049 tokens
-    'text-babbage-001': 'openai',  # 2,049 tokens
-    'text-ada-001': 'openai',  # 2,049 tokens
-    'text-embedding-ada-002': 'openai',  # 8191 tokens, 1536 dimensions
-    'whisper-1': 'openai'
-}
-
-max_context_token_length = {
-    'claude-instant-1': 100000,
-    'claude-2': 100000,
-    'gpt-4': 8192,
-    'gpt-4-32k': 32768,
-    'gpt-3.5-turbo': 4096,
-    'gpt-3.5-turbo-16k': 16384,
-    'text-davinci-003': 4097,
-    'text-davinci-002': 4097,
-    'text-curie-001': 2049,
-    'text-babbage-001': 2049,
-    'text-ada-001': 2049,
-    'text-embedding-ada-002': 8191,
-}
-
-models_by_mode = {
-    'chat': [
-        'claude-instant-1',  # 100,000 tokens
-        'claude-2',  # 100,000 tokens
-        'gpt-4',  # 8,192 tokens
-        'gpt-4-32k',  # 32,768 tokens
-        'gpt-3.5-turbo',  # 4,096 tokens
-        'gpt-3.5-turbo-16k',  # 16,384 tokens
-    ],
-    'completion': [
-        'claude-instant-1',  # 100,000 tokens
-        'claude-2',  # 100,000 tokens
-        'gpt-4',  # 8,192 tokens
-        'gpt-4-32k',  # 32,768 tokens
-        'gpt-3.5-turbo',  # 4,096 tokens
-        'gpt-3.5-turbo-16k',  # 16,384 tokens
-        'text-davinci-003',  # 4,097 tokens
-        'text-davinci-002'  # 4,097 tokens
-        'text-curie-001',  # 2,049 tokens
-        'text-babbage-001',  # 2,049 tokens
-        'text-ada-001'  # 2,049 tokens
-    ],
-    'embedding': [
-        'text-embedding-ada-002'  # 8191 tokens, 1536 dimensions
-    ]
-}
-
-model_currency = 'USD'
-
-model_prices = {
-    'claude-instant-1': {
-        'prompt': Decimal('0.00163'),
-        'completion': Decimal('0.00551'),
-    },
-    'claude-2': {
-        'prompt': Decimal('0.01102'),
-        'completion': Decimal('0.03268'),
-    },
-    'gpt-4': {
-        'prompt': Decimal('0.03'),
-        'completion': Decimal('0.06'),
-    },
-    'gpt-4-32k': {
-        'prompt': Decimal('0.06'),
-        'completion': Decimal('0.12')
-    },
-    'gpt-3.5-turbo': {
-        'prompt': Decimal('0.0015'),
-        'completion': Decimal('0.002')
-    },
-    'gpt-3.5-turbo-16k': {
-        'prompt': Decimal('0.003'),
-        'completion': Decimal('0.004')
-    },
-    'text-davinci-003': {
-        'prompt': Decimal('0.02'),
-        'completion': Decimal('0.02')
-    },
-    'text-curie-001': {
-        'prompt': Decimal('0.002'),
-        'completion': Decimal('0.002')
-    },
-    'text-babbage-001': {
-        'prompt': Decimal('0.0005'),
-        'completion': Decimal('0.0005')
-    },
-    'text-ada-001': {
-        'prompt': Decimal('0.0004'),
-        'completion': Decimal('0.0004')
-    },
-    'text-embedding-ada-002': {
-        'usage': Decimal('0.0001'),
-    }
-}
-
-agent_model_name = 'text-davinci-003'
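Review note: the static model/price/currency tables above disappear in favor of asking the model instance itself, which is the pattern the next file adopts. A hedged sketch of the replacement, given a `BaseLLM` named `model_instance` (the numeric values are examples, not the real rates):

```python
from decimal import Decimal

# Pricing now comes from the model instance instead of a global dict;
# get_token_price / get_currency are the BaseLLM APIs used in this PR.
message_unit_price = model_instance.get_token_price(1, MessageType.HUMAN)     # e.g. Decimal('0.0015')
answer_unit_price = model_instance.get_token_price(1, MessageType.ASSISTANT)  # e.g. Decimal('0.002')
currency = model_instance.get_currency()                                      # e.g. 'USD'
```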
model_instance.get_num_tokens(to_prompt_messages([system_message])) if not self.conversation: self.is_new_conversation = True self.conversation = Conversation( app_id=self.app_model_config.app_id, app_model_config_id=self.app_model_config.id, - model_provider=self.model_dict.get('provider'), + model_provider=self.provider_name, model_id=self.model_name, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, mode=self.mode, @@ -117,7 +120,7 @@ class ConversationMessageTask: self.message = Message( app_id=self.app_model_config.app_id, - model_provider=self.model_dict.get('provider'), + model_provider=self.provider_name, model_id=self.model_name, override_model_configs=json.dumps(override_model_configs) if override_model_configs else None, conversation_id=self.conversation.id, @@ -131,7 +134,7 @@ class ConversationMessageTask: answer_unit_price=0, provider_response_latency=0, total_price=0, - currency=llm_constant.model_currency, + currency=self.model_instance.get_currency(), from_source=('console' if isinstance(self.user, Account) else 'api'), from_end_user_id=(self.user.id if isinstance(self.user, EndUser) else None), from_account_id=(self.user.id if isinstance(self.user, Account) else None), @@ -145,12 +148,10 @@ class ConversationMessageTask: self._pub_handler.pub_text(text) def save_message(self, llm_message: LLMMessage, by_stopped: bool = False): - model_name = self.app_model_config.model_dict.get('name') - message_tokens = llm_message.prompt_tokens answer_tokens = llm_message.completion_tokens - message_unit_price = llm_constant.model_prices[model_name]['prompt'] - answer_unit_price = llm_constant.model_prices[model_name]['completion'] + message_unit_price = self.model_instance.get_token_price(1, MessageType.HUMAN) + answer_unit_price = self.model_instance.get_token_price(1, MessageType.ASSISTANT) total_price = self.calc_total_price(message_tokens, message_unit_price, answer_tokens, answer_unit_price) @@ -163,8 +164,6 @@ class ConversationMessageTask: self.message.provider_response_latency = llm_message.latency self.message.total_price = total_price - self.update_provider_quota() - db.session.commit() message_was_created.send( @@ -176,20 +175,6 @@ class ConversationMessageTask: if not by_stopped: self.end() - def update_provider_quota(self): - llm_provider_service = LLMProviderService( - tenant_id=self.app.tenant_id, - provider_name=self.message.model_provider, - ) - - provider = llm_provider_service.get_provider_db_record() - if provider and provider.provider_type == ProviderType.SYSTEM.value: - db.session.query(Provider).filter( - Provider.tenant_id == self.app.tenant_id, - Provider.provider_name == provider.provider_name, - Provider.quota_limit > Provider.quota_used - ).update({'quota_used': Provider.quota_used + 1}) - def init_chain(self, chain_result: ChainResult): message_chain = MessageChain( message_id=self.message.id, @@ -229,10 +214,10 @@ class ConversationMessageTask: return message_agent_thought - def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_name: str, + def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instant: BaseLLM, agent_loop: AgentLoop): - agent_message_unit_price = llm_constant.model_prices[agent_model_name]['prompt'] - agent_answer_unit_price = llm_constant.model_prices[agent_model_name]['completion'] + agent_message_unit_price = agent_model_instant.get_token_price(1, MessageType.HUMAN) + agent_answer_unit_price = agent_model_instant.get_token_price(1, 
MessageType.ASSISTANT) loop_message_tokens = agent_loop.prompt_tokens loop_answer_tokens = agent_loop.completion_tokens @@ -253,7 +238,7 @@ class ConversationMessageTask: message_agent_thought.latency = agent_loop.latency message_agent_thought.tokens = agent_loop.prompt_tokens + agent_loop.completion_tokens message_agent_thought.total_price = loop_total_price - message_agent_thought.currency = llm_constant.model_currency + message_agent_thought.currency = agent_model_instant.get_currency() db.session.flush() def on_dataset_query_end(self, dataset_query_obj: DatasetQueryObj): diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 016e711378..786ae4469d 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Sequence from langchain.schema import Document from sqlalchemy import func -from core.llm.token_calculator import TokenCalculator +from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @@ -13,12 +13,10 @@ class DatesetDocumentStore: self, dataset: Dataset, user_id: str, - embedding_model_name: str, document_id: Optional[str] = None, ): self._dataset = dataset self._user_id = user_id - self._embedding_model_name = embedding_model_name self._document_id = document_id @classmethod @@ -39,10 +37,6 @@ class DatesetDocumentStore: def user_id(self) -> Any: return self._user_id - @property - def embedding_model_name(self) -> Any: - return self._embedding_model_name - @property def docs(self) -> Dict[str, Document]: document_segments = db.session.query(DocumentSegment).filter( @@ -74,6 +68,10 @@ class DatesetDocumentStore: if max_position is None: max_position = 0 + embedding_model = ModelFactory.get_embedding_model( + tenant_id=self._dataset.tenant_id + ) + for doc in docs: if not isinstance(doc, Document): raise ValueError("doc must be a Document") @@ -88,7 +86,7 @@ class DatesetDocumentStore: ) # calc embedding use tokens - tokens = TokenCalculator.get_num_tokens(self._embedding_model_name, doc.page_content) + tokens = embedding_model.get_num_tokens(doc.page_content) if not segment_document: max_position += 1 diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 045b13ea3f..63bab8cd54 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -4,14 +4,14 @@ from typing import List from langchain.embeddings.base import Embeddings from sqlalchemy.exc import IntegrityError -from core.llm.wrappers.openai_wrapper import handle_openai_exceptions +from core.model_providers.models.embedding.base import BaseEmbedding from extensions.ext_database import db from libs import helper from models.dataset import Embedding class CacheEmbedding(Embeddings): - def __init__(self, embeddings: Embeddings): + def __init__(self, embeddings: BaseEmbedding): self._embeddings = embeddings def embed_documents(self, texts: List[str]) -> List[List[float]]: @@ -21,48 +21,54 @@ class CacheEmbedding(Embeddings): embedding_queue_texts = [] for text in texts: hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(hash=hash).first() + embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() if embedding: text_embeddings.append(embedding.get_embedding()) else: embedding_queue_texts.append(text) - embedding_results = 
self._embeddings.embed_documents(embedding_queue_texts) - - i = 0 - for text in embedding_queue_texts: - hash = helper.generate_text_hash(text) - + if embedding_queue_texts: try: - embedding = Embedding(hash=hash) - embedding.set_embedding(embedding_results[i]) - db.session.add(embedding) - db.session.commit() - except IntegrityError: - db.session.rollback() - continue - except: - logging.exception('Failed to add embedding to db') - continue - finally: - i += 1 + embedding_results = self._embeddings.client.embed_documents(embedding_queue_texts) + except Exception as ex: + raise self._embeddings.handle_exceptions(ex) - text_embeddings.extend(embedding_results) + i = 0 + for text in embedding_queue_texts: + hash = helper.generate_text_hash(text) + + try: + embedding = Embedding(model_name=self._embeddings.name, hash=hash) + embedding.set_embedding(embedding_results[i]) + db.session.add(embedding) + db.session.commit() + except IntegrityError: + db.session.rollback() + continue + except: + logging.exception('Failed to add embedding to db') + continue + finally: + i += 1 + + text_embeddings.extend(embedding_results) return text_embeddings - @handle_openai_exceptions def embed_query(self, text: str) -> List[float]: """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) - embedding = db.session.query(Embedding).filter_by(hash=hash).first() + embedding = db.session.query(Embedding).filter_by(model_name=self._embeddings.name, hash=hash).first() if embedding: return embedding.get_embedding() - embedding_results = self._embeddings.embed_query(text) + try: + embedding_results = self._embeddings.client.embed_query(text) + except Exception as ex: + raise self._embeddings.handle_exceptions(ex) try: - embedding = Embedding(hash=hash) + embedding = Embedding(model_name=self._embeddings.name, hash=hash) embedding.set_embedding(embedding_results) db.session.add(embedding) db.session.commit() @@ -72,3 +78,5 @@ class CacheEmbedding(Embeddings): logging.exception('Failed to add embedding to db') return embedding_results + + diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index a5294add23..77cf8a2346 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -1,13 +1,10 @@ import logging -from langchain import PromptTemplate -from langchain.chat_models.base import BaseChatModel -from langchain.schema import HumanMessage, OutputParserException, BaseMessage, SystemMessage +from langchain.schema import OutputParserException -from core.constant import llm_constant -from core.llm.llm_builder import LLMBuilder -from core.llm.streamable_open_ai import StreamableOpenAI -from core.llm.token_calculator import TokenCalculator +from core.model_providers.model_factory import ModelFactory +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser @@ -15,9 +12,6 @@ from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTempla from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \ GENERATOR_QA_PROMPT -# gpt-3.5-turbo works not well -generate_base_model = 'text-davinci-003' - class LLMGenerator: @classmethod @@ -28,29 
+22,35 @@ class LLMGenerator: query = query[:300] + "...[TRUNCATED]..." + query[-300:] prompt = prompt.format(query=query) - llm: StreamableOpenAI = LLMBuilder.to_llm( + + model_instance = ModelFactory.get_text_generation_model( tenant_id=tenant_id, - model_name='gpt-3.5-turbo', - max_tokens=50, - timeout=600 + model_kwargs=ModelKwargs( + max_tokens=50 + ) ) - if isinstance(llm, BaseChatModel): - prompt = [HumanMessage(content=prompt)] - - response = llm.generate([prompt]) - answer = response.generations[0][0].text + prompts = [PromptMessage(content=prompt)] + response = model_instance.run(prompts) + answer = response.content return answer.strip() @classmethod def generate_conversation_summary(cls, tenant_id: str, messages): max_tokens = 200 - model = 'gpt-3.5-turbo' + + model_instance = ModelFactory.get_text_generation_model( + tenant_id=tenant_id, + model_kwargs=ModelKwargs( + max_tokens=max_tokens + ) + ) prompt = CONVERSATION_SUMMARY_PROMPT prompt_with_empty_context = prompt.format(context='') - prompt_tokens = TokenCalculator.get_num_tokens(model, prompt_with_empty_context) - rest_tokens = llm_constant.max_context_token_length[model] - prompt_tokens - max_tokens - 1 + prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)]) + max_context_token_length = model_instance.model_rules.max_tokens.max + rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1 context = '' for message in messages: @@ -68,25 +68,16 @@ class LLMGenerator: answer = message.answer message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer - if rest_tokens - TokenCalculator.get_num_tokens(model, context + message_qa_text) > 0: + if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0: context += message_qa_text if not context: return '[message too long, no summary]' prompt = prompt.format(context=context) - - llm: StreamableOpenAI = LLMBuilder.to_llm( - tenant_id=tenant_id, - model_name=model, - max_tokens=max_tokens - ) - - if isinstance(llm, BaseChatModel): - prompt = [HumanMessage(content=prompt)] - - response = llm.generate([prompt]) - answer = response.generations[0][0].text + prompts = [PromptMessage(content=prompt)] + response = model_instance.run(prompts) + answer = response.content return answer.strip() @classmethod @@ -94,16 +85,13 @@ class LLMGenerator: prompt = INTRODUCTION_GENERATE_PROMPT prompt = prompt.format(prompt=pre_prompt) - llm: StreamableOpenAI = LLMBuilder.to_llm( - tenant_id=tenant_id, - model_name=generate_base_model, + model_instance = ModelFactory.get_text_generation_model( + tenant_id=tenant_id ) - if isinstance(llm, BaseChatModel): - prompt = [HumanMessage(content=prompt)] - - response = llm.generate([prompt]) - answer = response.generations[0][0].text + prompts = [PromptMessage(content=prompt)] + response = model_instance.run(prompts) + answer = response.content return answer.strip() @classmethod @@ -119,23 +107,19 @@ class LLMGenerator: _input = prompt.format_prompt(histories=histories) - llm: StreamableOpenAI = LLMBuilder.to_llm( + model_instance = ModelFactory.get_text_generation_model( tenant_id=tenant_id, - model_name='gpt-3.5-turbo', - temperature=0, - max_tokens=256 + model_kwargs=ModelKwargs( + max_tokens=256, + temperature=0 + ) ) - if isinstance(llm, BaseChatModel): - query = [HumanMessage(content=_input.to_string())] - else: - query = _input.to_string() + prompts = [PromptMessage(content=_input.to_string())] try: - output = llm(query) - if isinstance(output, 
BaseMessage): - output = output.content - questions = output_parser.parse(output) + output = model_instance.run(prompts) + questions = output_parser.parse(output.content) except Exception: logging.exception("Error generating suggested questions after answer") questions = [] @@ -160,21 +144,19 @@ class LLMGenerator: _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve) - llm: StreamableOpenAI = LLMBuilder.to_llm( + model_instance = ModelFactory.get_text_generation_model( tenant_id=tenant_id, - model_name=generate_base_model, - temperature=0, - max_tokens=512 + model_kwargs=ModelKwargs( + max_tokens=512, + temperature=0 + ) ) - if isinstance(llm, BaseChatModel): - query = [HumanMessage(content=_input.to_string())] - else: - query = _input.to_string() + prompts = [PromptMessage(content=_input.to_string())] try: - output = llm(query) - rule_config = output_parser.parse(output) + output = model_instance.run(prompts) + rule_config = output_parser.parse(output.content) except OutputParserException: raise ValueError('Please give a valid input for intended audience or hoping to solve problems.') except Exception: @@ -188,25 +170,21 @@ class LLMGenerator: return rule_config @classmethod - async def generate_qa_document(cls, llm: StreamableOpenAI, query): + def generate_qa_document(cls, tenant_id: str, query): prompt = GENERATOR_QA_PROMPT + model_instance = ModelFactory.get_text_generation_model( + tenant_id=tenant_id, + model_kwargs=ModelKwargs( + max_tokens=2000 + ) + ) - if isinstance(llm, BaseChatModel): - prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] + prompts = [ + PromptMessage(content=prompt, type=MessageType.SYSTEM), + PromptMessage(content=query) + ] - response = llm.generate([prompt]) - answer = response.generations[0][0].text - return answer.strip() - - @classmethod - def generate_qa_document_sync(cls, llm: StreamableOpenAI, query): - prompt = GENERATOR_QA_PROMPT - - - if isinstance(llm, BaseChatModel): - prompt = [SystemMessage(content=prompt), HumanMessage(content=query)] - - response = llm.generate([prompt]) - answer = response.generations[0][0].text + response = model_instance.run(prompts) + answer = response.content return answer.strip() diff --git a/api/tests/test_helpers/__init__.py b/api/core/helper/__init__.py similarity index 100% rename from api/tests/test_helpers/__init__.py rename to api/core/helper/__init__.py diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py new file mode 100644 index 0000000000..fa94867ba2 --- /dev/null +++ b/api/core/helper/encrypter.py @@ -0,0 +1,20 @@ +import base64 + +from extensions.ext_database import db +from libs import rsa + +from models.account import Tenant + + +def obfuscated_token(token: str): + return token[:6] + '*' * (len(token) - 8) + token[-2:] + + +def encrypt_token(tenant_id: str, token: str): + tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() + encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) + return base64.b64encode(encrypted_token).decode() + + +def decrypt_token(tenant_id: str, token: str): + return rsa.decrypt(base64.b64decode(token), tenant_id) diff --git a/api/core/index/index.py b/api/core/index/index.py index 657ad221e2..316b604566 100644 --- a/api/core/index/index.py +++ b/api/core/index/index.py @@ -1,10 +1,9 @@ from flask import current_app -from langchain.embeddings import OpenAIEmbeddings from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import 
KeywordTableIndex, KeywordTableConfig from core.index.vector_index.vector_index import VectorIndex -from core.llm.llm_builder import LLMBuilder +from core.model_providers.model_factory import ModelFactory from models.dataset import Dataset @@ -15,16 +14,11 @@ class IndexBuilder: if not ignore_high_quality_check and dataset.indexing_technique != 'high_quality': return None - model_credentials = LLMBuilder.get_model_credentials( - tenant_id=dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'), - model_name='text-embedding-ada-002' + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id ) - embeddings = CacheEmbedding(OpenAIEmbeddings( - max_retries=1, - **model_credentials - )) + embeddings = CacheEmbedding(embedding_model) return VectorIndex( dataset=dataset, diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 56df8f2316..57e1e8fab6 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -1,4 +1,3 @@ -import concurrent import datetime import json import logging @@ -6,7 +5,6 @@ import re import threading import time import uuid -from concurrent.futures import ThreadPoolExecutor from typing import Optional, List, cast from flask_login import current_user @@ -18,11 +16,10 @@ from core.data_loader.loader.notion import NotionLoader from core.docstore.dataset_docstore import DatesetDocumentStore from core.generator.llm_generator import LLMGenerator from core.index.index import IndexBuilder -from core.llm.error import ProviderTokenNotInitError -from core.llm.llm_builder import LLMBuilder -from core.llm.streamable_open_ai import StreamableOpenAI +from core.model_providers.error import ProviderTokenNotInitError +from core.model_providers.model_factory import ModelFactory +from core.model_providers.models.entity.message import MessageType from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter -from core.llm.token_calculator import TokenCalculator from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -35,9 +32,8 @@ from models.source import DataSourceBinding class IndexingRunner: - def __init__(self, embedding_model_name: str = "text-embedding-ada-002"): + def __init__(self): self.storage = storage - self.embedding_model_name = embedding_model_name def run(self, dataset_documents: List[DatasetDocument]): """Run the indexing process.""" @@ -227,11 +223,15 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() - def file_indexing_estimate(self, file_details: List[UploadFile], tmp_processing_rule: dict, + def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict, doc_form: str = None) -> dict: """ Estimate the indexing for the document. 
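
The estimate methods no longer assume a fixed `text-embedding-ada-002` model; the embedding model is now resolved per tenant. A minimal sketch of the new counting flow, assuming the `ModelFactory` and `get_num_tokens` interfaces introduced in this diff:

```python
# Sketch only: token counting against the tenant's configured embedding
# model, mirroring file_indexing_estimate below.
from core.model_providers.model_factory import ModelFactory

def count_embedding_tokens(tenant_id: str, texts: list) -> int:
    # Resolved per tenant instead of hard-coding 'text-embedding-ada-002'.
    embedding_model = ModelFactory.get_embedding_model(tenant_id=tenant_id)
    return sum(embedding_model.get_num_tokens(text) for text in texts)
```
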
""" + embedding_model = ModelFactory.get_embedding_model( + tenant_id=tenant_id + ) + tokens = 0 preview_texts = [] total_segments = 0 @@ -253,44 +253,49 @@ class IndexingRunner: splitter=splitter, processing_rule=processing_rule ) + total_segments += len(documents) + for document in documents: if len(preview_texts) < 5: preview_texts.append(document.page_content) - tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, - self.filter_string(document.page_content)) + tokens += embedding_model.get_num_tokens(self.filter_string(document.page_content)) + + text_generation_model = ModelFactory.get_text_generation_model( + tenant_id=tenant_id + ) + if doc_form and doc_form == 'qa_model': if len(preview_texts) > 0: # qa model document - llm: StreamableOpenAI = LLMBuilder.to_llm( - tenant_id=current_user.current_tenant_id, - model_name='gpt-3.5-turbo', - max_tokens=2000 - ) - response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) + response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) document_qa_list = self.format_split_text(response) return { "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( - TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), - "currency": TokenCalculator.get_currency(self.embedding_model_name), + text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), + "currency": embedding_model.get_currency(), "qa_preview": document_qa_list, "preview": preview_texts } return { "total_segments": total_segments, "tokens": tokens, - "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), - "currency": TokenCalculator.get_currency(self.embedding_model_name), + "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), + "currency": embedding_model.get_currency(), "preview": preview_texts } - def notion_indexing_estimate(self, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: + def notion_indexing_estimate(self, tenant_id: str, notion_info_list: list, tmp_processing_rule: dict, doc_form: str = None) -> dict: """ Estimate the indexing for the document. 
""" + embedding_model = ModelFactory.get_embedding_model( + tenant_id=tenant_id + ) + # load data from notion tokens = 0 preview_texts = [] @@ -336,31 +341,31 @@ class IndexingRunner: if len(preview_texts) < 5: preview_texts.append(document.page_content) - tokens += TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) + tokens += embedding_model.get_num_tokens(document.page_content) + + text_generation_model = ModelFactory.get_text_generation_model( + tenant_id=tenant_id + ) + if doc_form and doc_form == 'qa_model': if len(preview_texts) > 0: # qa model document - llm: StreamableOpenAI = LLMBuilder.to_llm( - tenant_id=current_user.current_tenant_id, - model_name='gpt-3.5-turbo', - max_tokens=2000 - ) - response = LLMGenerator.generate_qa_document_sync(llm, preview_texts[0]) + response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0]) document_qa_list = self.format_split_text(response) return { "total_segments": total_segments * 20, "tokens": total_segments * 2000, "total_price": '{:f}'.format( - TokenCalculator.get_token_price('gpt-3.5-turbo', total_segments * 2000, 'completion')), - "currency": TokenCalculator.get_currency(self.embedding_model_name), + text_generation_model.get_token_price(total_segments * 2000, MessageType.HUMAN)), + "currency": embedding_model.get_currency(), "qa_preview": document_qa_list, "preview": preview_texts } return { "total_segments": total_segments, "tokens": tokens, - "total_price": '{:f}'.format(TokenCalculator.get_token_price(self.embedding_model_name, tokens)), - "currency": TokenCalculator.get_currency(self.embedding_model_name), + "total_price": '{:f}'.format(embedding_model.get_token_price(tokens)), + "currency": embedding_model.get_currency(), "preview": preview_texts } @@ -459,7 +464,6 @@ class IndexingRunner: doc_store = DatesetDocumentStore( dataset=dataset, user_id=dataset_document.created_by, - embedding_model_name=self.embedding_model_name, document_id=dataset_document.id ) @@ -513,17 +517,12 @@ class IndexingRunner: all_documents.extend(split_documents) # processing qa document if document_form == 'qa_model': - llm: StreamableOpenAI = LLMBuilder.to_llm( - tenant_id=tenant_id, - model_name='gpt-3.5-turbo', - max_tokens=2000 - ) for i in range(0, len(all_documents), 10): threads = [] sub_documents = all_documents[i:i + 10] for doc in sub_documents: document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={ - 'llm': llm, 'document_node': doc, 'all_qa_documents': all_qa_documents}) + 'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents}) threads.append(document_format_thread) document_format_thread.start() for thread in threads: @@ -531,13 +530,13 @@ class IndexingRunner: return all_qa_documents return all_documents - def format_qa_document(self, llm: StreamableOpenAI, document_node, all_qa_documents): + def format_qa_document(self, tenant_id: str, document_node, all_qa_documents): format_documents = [] if document_node.page_content is None or not document_node.page_content.strip(): return try: # qa model document - response = LLMGenerator.generate_qa_document_sync(llm, document_node.page_content) + response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content) document_qa_list = self.format_split_text(response) qa_documents = [] for result in document_qa_list: @@ -638,6 +637,10 @@ class IndexingRunner: vector_index = IndexBuilder.get_index(dataset, 'high_quality') keyword_table_index = 
IndexBuilder.get_index(dataset, 'economy') + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id + ) + # chunk nodes by chunk size indexing_start_at = time.perf_counter() tokens = 0 @@ -648,7 +651,7 @@ class IndexingRunner: chunk_documents = documents[i:i + chunk_size] tokens += sum( - TokenCalculator.get_num_tokens(self.embedding_model_name, document.page_content) + embedding_model.get_num_tokens(document.page_content) for document in chunk_documents ) diff --git a/api/core/llm/llm_builder.py b/api/core/llm/llm_builder.py deleted file mode 100644 index f054939fd5..0000000000 --- a/api/core/llm/llm_builder.py +++ /dev/null @@ -1,148 +0,0 @@ -from typing import Union, Optional, List - -from langchain.callbacks.base import BaseCallbackHandler - -from core.constant import llm_constant -from core.llm.error import ProviderTokenNotInitError -from core.llm.provider.base import BaseProvider -from core.llm.provider.llm_provider_service import LLMProviderService -from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI -from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI -from core.llm.streamable_chat_anthropic import StreamableChatAnthropic -from core.llm.streamable_chat_open_ai import StreamableChatOpenAI -from core.llm.streamable_open_ai import StreamableOpenAI -from models.provider import ProviderType, ProviderName - - -class LLMBuilder: - """ - This class handles the following logic: - 1. For providers with the name 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config. - 2. For providers with the name 'Azure OpenAI', encrypted_config stores the serialized values of four fields, as shown below: - OPENAI_API_TYPE=azure - OPENAI_API_VERSION=2022-12-01 - OPENAI_API_BASE=https://your-resource-name.openai.azure.com - OPENAI_API_KEY= - 3. For providers with the name 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config. - 4. For providers with the name 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config. - 5. For providers with the name 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in encrypted_config. - 6. Providers with the provider_type 'CUSTOM' can be created through the admin interface, while 'System' providers cannot be created through the admin interface. - 7. If both CUSTOM and System providers exist in the records, the CUSTOM provider is preferred by default, but this preference can be changed via an input parameter. - 8. For providers with the provider_type 'System', the quota_used must not exceed quota_limit. If the quota is exceeded, the provider cannot be used. Currently, only the TRIAL quota_type is supported, which is permanently non-resetting. 
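
Rules 7 and 8 of this now-removed docstring are the behavior the new provider system carries forward. Condensed into a standalone Python sketch for clarity (not the shipped implementation):

```python
def pick_provider(providers):
    # Rule 7: prefer a valid CUSTOM provider over a SYSTEM one.
    for provider in providers:
        if provider.provider_type == 'custom' and provider.is_valid:
            return provider
    # Rule 8: SYSTEM providers are usable only while quota remains
    # (TRIAL quota is permanently non-resetting).
    for provider in providers:
        if provider.provider_type == 'system' and provider.is_valid \
                and (provider.quota_used or 0) < (provider.quota_limit or 0):
            return provider
    return None
```
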
- """ - - @classmethod - def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[StreamableOpenAI, StreamableChatOpenAI]: - provider = cls.get_default_provider(tenant_id, model_name) - - model_credentials = cls.get_model_credentials(tenant_id, provider, model_name) - - llm_cls = None - mode = cls.get_mode_by_model(model_name) - if mode == 'chat': - if provider == ProviderName.OPENAI.value: - llm_cls = StreamableChatOpenAI - elif provider == ProviderName.AZURE_OPENAI.value: - llm_cls = StreamableAzureChatOpenAI - elif provider == ProviderName.ANTHROPIC.value: - llm_cls = StreamableChatAnthropic - elif mode == 'completion': - if provider == ProviderName.OPENAI.value: - llm_cls = StreamableOpenAI - elif provider == ProviderName.AZURE_OPENAI.value: - llm_cls = StreamableAzureOpenAI - - if not llm_cls: - raise ValueError(f"model name {model_name} is not supported.") - - model_kwargs = { - 'model_name': model_name, - 'temperature': kwargs.get('temperature', 0), - 'max_tokens': kwargs.get('max_tokens', 256), - 'top_p': kwargs.get('top_p', 1), - 'frequency_penalty': kwargs.get('frequency_penalty', 0), - 'presence_penalty': kwargs.get('presence_penalty', 0), - 'callbacks': kwargs.get('callbacks', None), - 'streaming': kwargs.get('streaming', False), - } - - model_kwargs.update(model_credentials) - model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs) - - return llm_cls(**model_kwargs) - - @classmethod - def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False, - callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[StreamableOpenAI, StreamableChatOpenAI]: - model_name = model.get("name") - completion_params = model.get("completion_params", {}) - - return cls.to_llm( - tenant_id=tenant_id, - model_name=model_name, - temperature=completion_params.get('temperature', 0), - max_tokens=completion_params.get('max_tokens', 256), - top_p=completion_params.get('top_p', 0), - frequency_penalty=completion_params.get('frequency_penalty', 0.1), - presence_penalty=completion_params.get('presence_penalty', 0.1), - streaming=streaming, - callbacks=callbacks - ) - - @classmethod - def get_mode_by_model(cls, model_name: str) -> str: - if not model_name: - raise ValueError(f"empty model name is not supported.") - - if model_name in llm_constant.models_by_mode['chat']: - return "chat" - elif model_name in llm_constant.models_by_mode['completion']: - return "completion" - else: - raise ValueError(f"model name {model_name} is not supported.") - - @classmethod - def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict: - """ - Returns the API credentials for the given tenant_id and model_name, based on the model's provider. - Raises an exception if the model_name is not found or if the provider is not found. 
- """ - if not model_name: - raise Exception('model name not found') - # - # if model_name not in llm_constant.models: - # raise Exception('model {} not found'.format(model_name)) - - # model_provider = llm_constant.models[model_name] - - provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider) - return provider_service.get_credentials(model_name) - - @classmethod - def get_default_provider(cls, tenant_id: str, model_name: str) -> str: - provider_name = llm_constant.models[model_name] - - if provider_name == 'openai': - # get the default provider (openai / azure_openai) for the tenant - openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value) - azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value) - - provider = None - if openai_provider and openai_provider.provider_type == ProviderType.CUSTOM.value: - provider = openai_provider - elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.CUSTOM.value: - provider = azure_openai_provider - elif openai_provider and openai_provider.provider_type == ProviderType.SYSTEM.value: - provider = openai_provider - elif azure_openai_provider and azure_openai_provider.provider_type == ProviderType.SYSTEM.value: - provider = azure_openai_provider - - if not provider: - raise ProviderTokenNotInitError( - f"No valid {provider_name} model provider credentials found. " - f"Please go to Settings -> Model Provider to complete your provider credentials." - ) - - provider_name = provider.provider_name - - return provider_name diff --git a/api/core/llm/moderation.py b/api/core/llm/moderation.py deleted file mode 100644 index d18d6fc5c2..0000000000 --- a/api/core/llm/moderation.py +++ /dev/null @@ -1,15 +0,0 @@ -import openai -from models.provider import ProviderName - - -class Moderation: - - def __init__(self, provider: str, api_key: str): - self.provider = provider - self.api_key = api_key - - if self.provider == ProviderName.OPENAI.value: - self.client = openai.Moderation - - def moderate(self, text): - return self.client.create(input=text, api_key=self.api_key) diff --git a/api/core/llm/provider/anthropic_provider.py b/api/core/llm/provider/anthropic_provider.py deleted file mode 100644 index d6165d0329..0000000000 --- a/api/core/llm/provider/anthropic_provider.py +++ /dev/null @@ -1,138 +0,0 @@ -import json -import logging -from typing import Optional, Union - -import anthropic -from langchain.chat_models import ChatAnthropic -from langchain.schema import HumanMessage - -from core import hosted_llm_credentials -from core.llm.error import ProviderTokenNotInitError -from core.llm.provider.base import BaseProvider -from core.llm.provider.errors import ValidateFailedError -from models.provider import ProviderName, ProviderType - - -class AnthropicProvider(BaseProvider): - def get_models(self, model_id: Optional[str] = None) -> list[dict]: - return [ - { - 'id': 'claude-instant-1', - 'name': 'claude-instant-1', - }, - { - 'id': 'claude-2', - 'name': 'claude-2', - }, - ] - - def get_credentials(self, model_id: Optional[str] = None) -> dict: - return self.get_provider_api_key(model_id=model_id) - - def get_provider_name(self): - return ProviderName.ANTHROPIC - - def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: - """ - Returns the provider configs. 
- """ - try: - config = self.get_provider_api_key(only_custom=only_custom) - except: - config = { - 'anthropic_api_key': '' - } - - if obfuscated: - if not config.get('anthropic_api_key'): - config = { - 'anthropic_api_key': '' - } - - config['anthropic_api_key'] = self.obfuscated_token(config.get('anthropic_api_key')) - return config - - return config - - def get_encrypted_token(self, config: Union[dict | str]): - """ - Returns the encrypted token. - """ - return json.dumps({ - 'anthropic_api_key': self.encrypt_token(config['anthropic_api_key']) - }) - - def get_decrypted_token(self, token: str): - """ - Returns the decrypted token. - """ - config = json.loads(token) - config['anthropic_api_key'] = self.decrypt_token(config['anthropic_api_key']) - return config - - def get_token_type(self): - return dict - - def config_validate(self, config: Union[dict | str]): - """ - Validates the given config. - """ - # check OpenAI / Azure OpenAI credential is valid - openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.OPENAI.value) - azure_openai_provider = BaseProvider.get_valid_provider(self.tenant_id, ProviderName.AZURE_OPENAI.value) - - provider = None - if openai_provider: - provider = openai_provider - elif azure_openai_provider: - provider = azure_openai_provider - - if not provider: - raise ValidateFailedError(f"OpenAI or Azure OpenAI provider must be configured first.") - - if provider.provider_type == ProviderType.SYSTEM.value: - quota_used = provider.quota_used if provider.quota_used is not None else 0 - quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 - if quota_used >= quota_limit: - raise ValidateFailedError(f"Your quota for Dify Hosted OpenAI has been exhausted, " - f"please configure OpenAI or Azure OpenAI provider first.") - - try: - if not isinstance(config, dict): - raise ValueError('Config must be a object.') - - if 'anthropic_api_key' not in config: - raise ValueError('anthropic_api_key must be provided.') - - chat_llm = ChatAnthropic( - model='claude-instant-1', - anthropic_api_key=config['anthropic_api_key'], - max_tokens_to_sample=10, - temperature=0, - default_request_timeout=60 - ) - - messages = [ - HumanMessage( - content="ping" - ) - ] - - chat_llm(messages) - except anthropic.APIConnectionError as ex: - raise ValidateFailedError(f"Anthropic: Connection error, cause: {ex.__cause__}") - except (anthropic.APIStatusError, anthropic.RateLimitError) as ex: - raise ValidateFailedError(f"Anthropic: Error code: {ex.status_code} - " - f"{ex.body['error']['type']}: {ex.body['error']['message']}") - except Exception as ex: - logging.exception('Anthropic config validation failed') - raise ex - - def get_hosted_credentials(self) -> Union[str | dict]: - if not hosted_llm_credentials.anthropic or not hosted_llm_credentials.anthropic.api_key: - raise ProviderTokenNotInitError( - f"No valid {self.get_provider_name().value} model provider credentials found. " - f"Please go to Settings -> Model Provider to complete your provider credentials." 
- ) - - return {'anthropic_api_key': hosted_llm_credentials.anthropic.api_key} diff --git a/api/core/llm/provider/azure_provider.py b/api/core/llm/provider/azure_provider.py deleted file mode 100644 index 8d63450622..0000000000 --- a/api/core/llm/provider/azure_provider.py +++ /dev/null @@ -1,145 +0,0 @@ -import json -import logging -from typing import Optional, Union - -import openai -import requests - -from core.llm.provider.base import BaseProvider -from core.llm.provider.errors import ValidateFailedError -from models.provider import ProviderName - - -AZURE_OPENAI_API_VERSION = '2023-07-01-preview' - - -class AzureProvider(BaseProvider): - def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]: - return [] - - def check_embedding_model(self, credentials: Optional[dict] = None): - credentials = self.get_credentials('text-embedding-ada-002') if not credentials else credentials - try: - result = openai.Embedding.create(input=['test'], - engine='text-embedding-ada-002', - timeout=60, - api_key=str(credentials.get('openai_api_key')), - api_base=str(credentials.get('openai_api_base')), - api_type='azure', - api_version=str(credentials.get('openai_api_version')))["data"][0][ - "embedding"] - except openai.error.AuthenticationError as e: - raise AzureAuthenticationError(str(e)) - except openai.error.APIConnectionError as e: - raise AzureRequestFailedError( - 'Failed to request Azure OpenAI, please check your API Base Endpoint, The format is `https://xxx.openai.azure.com/`') - except openai.error.InvalidRequestError as e: - if e.http_status == 404: - raise AzureRequestFailedError("Please check your 'gpt-3.5-turbo' or 'text-embedding-ada-002' " - "deployment name is exists in Azure AI") - else: - raise AzureRequestFailedError( - 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) - except openai.error.OpenAIError as e: - raise AzureRequestFailedError( - 'Failed to request Azure OpenAI. cause: {}'.format(str(e))) - - if not isinstance(result, list): - raise AzureRequestFailedError('Failed to request Azure OpenAI.') - - def get_credentials(self, model_id: Optional[str] = None) -> dict: - """ - Returns the API credentials for Azure OpenAI as a dictionary. - """ - config = self.get_provider_api_key(model_id=model_id) - config['openai_api_type'] = 'azure' - config['openai_api_version'] = AZURE_OPENAI_API_VERSION - if model_id == 'text-embedding-ada-002': - config['deployment'] = model_id.replace('.', '') if model_id else None - config['chunk_size'] = 16 - else: - config['deployment_name'] = model_id.replace('.', '') if model_id else None - return config - - def get_provider_name(self): - return ProviderName.AZURE_OPENAI - - def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: - """ - Returns the provider configs. - """ - try: - config = self.get_provider_api_key(only_custom=only_custom) - except: - config = { - 'openai_api_type': 'azure', - 'openai_api_version': AZURE_OPENAI_API_VERSION, - 'openai_api_base': '', - 'openai_api_key': '' - } - - if obfuscated: - if not config.get('openai_api_key'): - config = { - 'openai_api_type': 'azure', - 'openai_api_version': AZURE_OPENAI_API_VERSION, - 'openai_api_base': '', - 'openai_api_key': '' - } - - config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key')) - return config - - return config - - def get_token_type(self): - return dict - - def config_validate(self, config: Union[dict | str]): - """ - Validates the given config. 
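
For context, the credential dictionary the deleted `AzureProvider.get_credentials` assembled had this shape (illustrative values):

```python
# Sketch of the dict returned by the deleted AzureProvider.get_credentials.
azure_credentials = {
    'openai_api_type': 'azure',
    'openai_api_version': '2023-07-01-preview',  # AZURE_OPENAI_API_VERSION
    'openai_api_base': 'https://your-resource-name.openai.azure.com',
    'openai_api_key': '<decrypted key>',
    # Embedding models additionally got 'deployment' and 'chunk_size': 16;
    # chat/completion models got 'deployment_name' instead.
}
```
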
- """ - try: - if not isinstance(config, dict): - raise ValueError('Config must be a object.') - - if 'openai_api_version' not in config: - config['openai_api_version'] = AZURE_OPENAI_API_VERSION - - self.check_embedding_model(credentials=config) - except ValidateFailedError as e: - raise e - except AzureAuthenticationError: - raise ValidateFailedError('Validation failed, please check your API Key.') - except AzureRequestFailedError as ex: - raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) - except Exception as ex: - logging.exception('Azure OpenAI Credentials validation failed') - raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex))) - - def get_encrypted_token(self, config: Union[dict | str]): - """ - Returns the encrypted token. - """ - return json.dumps({ - 'openai_api_type': 'azure', - 'openai_api_version': AZURE_OPENAI_API_VERSION, - 'openai_api_base': config['openai_api_base'], - 'openai_api_key': self.encrypt_token(config['openai_api_key']) - }) - - def get_decrypted_token(self, token: str): - """ - Returns the decrypted token. - """ - config = json.loads(token) - config['openai_api_key'] = self.decrypt_token(config['openai_api_key']) - return config - - -class AzureAuthenticationError(Exception): - pass - - -class AzureRequestFailedError(Exception): - pass diff --git a/api/core/llm/provider/base.py b/api/core/llm/provider/base.py deleted file mode 100644 index c3ff5cf237..0000000000 --- a/api/core/llm/provider/base.py +++ /dev/null @@ -1,132 +0,0 @@ -import base64 -from abc import ABC, abstractmethod -from typing import Optional, Union - -from core.constant import llm_constant -from core.llm.error import QuotaExceededError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError -from extensions.ext_database import db -from libs import rsa -from models.account import Tenant -from models.provider import Provider, ProviderType, ProviderName - - -class BaseProvider(ABC): - def __init__(self, tenant_id: str): - self.tenant_id = tenant_id - - def get_provider_api_key(self, model_id: Optional[str] = None, only_custom: bool = False) -> Union[str | dict]: - """ - Returns the decrypted API key for the given tenant_id and provider_name. - If the provider is of type SYSTEM and the quota is exceeded, raises a QuotaExceededError. - If the provider is not found or not valid, raises a ProviderTokenNotInitError. - """ - provider = self.get_provider(only_custom) - if not provider: - raise ProviderTokenNotInitError( - f"No valid {llm_constant.models[model_id]} model provider credentials found. " - f"Please go to Settings -> Model Provider to complete your provider credentials." - ) - - if provider.provider_type == ProviderType.SYSTEM.value: - quota_used = provider.quota_used if provider.quota_used is not None else 0 - quota_limit = provider.quota_limit if provider.quota_limit is not None else 0 - - if model_id and model_id == 'gpt-4': - raise ModelCurrentlyNotSupportError() - - if quota_used >= quota_limit: - raise QuotaExceededError() - - return self.get_hosted_credentials() - else: - return self.get_decrypted_token(provider.encrypted_config) - - def get_provider(self, only_custom: bool = False) -> Optional[Provider]: - """ - Returns the Provider instance for the given tenant_id and provider_name. - If both CUSTOM and System providers exist, the preferred provider will be returned based on the prefer_custom flag. 
- """ - return BaseProvider.get_valid_provider(self.tenant_id, self.get_provider_name().value, only_custom) - - @classmethod - def get_valid_provider(cls, tenant_id: str, provider_name: str = None, only_custom: bool = False) -> Optional[ - Provider]: - """ - Returns the Provider instance for the given tenant_id and provider_name. - If both CUSTOM and System providers exist. - """ - query = db.session.query(Provider).filter( - Provider.tenant_id == tenant_id - ) - - if provider_name: - query = query.filter(Provider.provider_name == provider_name) - - if only_custom: - query = query.filter(Provider.provider_type == ProviderType.CUSTOM.value) - - providers = query.order_by(Provider.provider_type.asc()).all() - - for provider in providers: - if provider.provider_type == ProviderType.CUSTOM.value and provider.is_valid and provider.encrypted_config: - return provider - elif provider.provider_type == ProviderType.SYSTEM.value and provider.is_valid: - return provider - - return None - - def get_hosted_credentials(self) -> Union[str | dict]: - raise ProviderTokenNotInitError( - f"No valid {self.get_provider_name().value} model provider credentials found. " - f"Please go to Settings -> Model Provider to complete your provider credentials." - ) - - def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: - """ - Returns the provider configs. - """ - try: - config = self.get_provider_api_key(only_custom=only_custom) - except: - config = '' - - if obfuscated: - return self.obfuscated_token(config) - - return config - - def obfuscated_token(self, token: str): - return token[:6] + '*' * (len(token) - 8) + token[-2:] - - def get_token_type(self): - return str - - def get_encrypted_token(self, config: Union[dict | str]): - return self.encrypt_token(config) - - def get_decrypted_token(self, token: str): - return self.decrypt_token(token) - - def encrypt_token(self, token): - tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() - encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) - return base64.b64encode(encrypted_token).decode() - - def decrypt_token(self, token): - return rsa.decrypt(base64.b64decode(token), self.tenant_id) - - @abstractmethod - def get_provider_name(self): - raise NotImplementedError - - @abstractmethod - def get_credentials(self, model_id: Optional[str] = None) -> dict: - raise NotImplementedError - - @abstractmethod - def get_models(self, model_id: Optional[str] = None) -> list[dict]: - raise NotImplementedError - - @abstractmethod - def config_validate(self, config: str): - raise NotImplementedError diff --git a/api/core/llm/provider/errors.py b/api/core/llm/provider/errors.py deleted file mode 100644 index 407b7f7906..0000000000 --- a/api/core/llm/provider/errors.py +++ /dev/null @@ -1,2 +0,0 @@ -class ValidateFailedError(Exception): - description = "Provider Validate failed" diff --git a/api/core/llm/provider/huggingface_provider.py b/api/core/llm/provider/huggingface_provider.py deleted file mode 100644 index b3dd3ed573..0000000000 --- a/api/core/llm/provider/huggingface_provider.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Optional - -from core.llm.provider.base import BaseProvider -from models.provider import ProviderName - - -class HuggingfaceProvider(BaseProvider): - def get_models(self, model_id: Optional[str] = None) -> list[dict]: - credentials = self.get_credentials(model_id) - # todo - return [] - - def get_credentials(self, model_id: Optional[str] = None) -> dict: - """ - 
Returns the API credentials for Huggingface as a dictionary, for the given tenant_id. - """ - return { - 'huggingface_api_key': self.get_provider_api_key(model_id=model_id) - } - - def get_provider_name(self): - return ProviderName.HUGGINGFACEHUB \ No newline at end of file diff --git a/api/core/llm/provider/llm_provider_service.py b/api/core/llm/provider/llm_provider_service.py deleted file mode 100644 index a520e3d6bb..0000000000 --- a/api/core/llm/provider/llm_provider_service.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import Optional, Union - -from core.llm.provider.anthropic_provider import AnthropicProvider -from core.llm.provider.azure_provider import AzureProvider -from core.llm.provider.base import BaseProvider -from core.llm.provider.huggingface_provider import HuggingfaceProvider -from core.llm.provider.openai_provider import OpenAIProvider -from models.provider import Provider - - -class LLMProviderService: - - def __init__(self, tenant_id: str, provider_name: str): - self.provider = self.init_provider(tenant_id, provider_name) - - def init_provider(self, tenant_id: str, provider_name: str) -> BaseProvider: - if provider_name == 'openai': - return OpenAIProvider(tenant_id) - elif provider_name == 'azure_openai': - return AzureProvider(tenant_id) - elif provider_name == 'anthropic': - return AnthropicProvider(tenant_id) - elif provider_name == 'huggingface': - return HuggingfaceProvider(tenant_id) - else: - raise Exception('provider {} not found'.format(provider_name)) - - def get_models(self, model_id: Optional[str] = None) -> list[dict]: - return self.provider.get_models(model_id) - - def get_credentials(self, model_id: Optional[str] = None) -> dict: - return self.provider.get_credentials(model_id) - - def get_provider_configs(self, obfuscated: bool = False, only_custom: bool = False) -> Union[str | dict]: - return self.provider.get_provider_configs(obfuscated=obfuscated, only_custom=only_custom) - - def get_provider_db_record(self) -> Optional[Provider]: - return self.provider.get_provider() - - def config_validate(self, config: Union[dict | str]): - """ - Validates the given config. - - :param config: - :raises: ValidateFailedError - """ - return self.provider.config_validate(config) - - def get_token_type(self): - return self.provider.get_token_type() - - def get_encrypted_token(self, config: Union[dict | str]): - return self.provider.get_encrypted_token(config) diff --git a/api/core/llm/provider/openai_provider.py b/api/core/llm/provider/openai_provider.py deleted file mode 100644 index b24e98e5d1..0000000000 --- a/api/core/llm/provider/openai_provider.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging -from typing import Optional, Union - -import openai -from openai.error import AuthenticationError, OpenAIError - -from core import hosted_llm_credentials -from core.llm.error import ProviderTokenNotInitError -from core.llm.moderation import Moderation -from core.llm.provider.base import BaseProvider -from core.llm.provider.errors import ValidateFailedError -from models.provider import ProviderName - - -class OpenAIProvider(BaseProvider): - def get_models(self, model_id: Optional[str] = None) -> list[dict]: - credentials = self.get_credentials(model_id) - response = openai.Model.list(**credentials) - - return [{ - 'id': model['id'], - 'name': model['id'], - } for model in response['data']] - - def get_credentials(self, model_id: Optional[str] = None) -> dict: - """ - Returns the credentials for the given tenant_id and provider_name. 
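
The deleted `LLMProviderService.init_provider` was a plain name-to-class dispatch; a condensed sketch, where the provider classes are the (also deleted) ones from `core.llm.provider`:

```python
def init_provider(tenant_id: str, provider_name: str):
    # Dispatch table condensed from the deleted if/elif chain.
    providers = {
        'openai': OpenAIProvider,
        'azure_openai': AzureProvider,
        'anthropic': AnthropicProvider,
        'huggingface': HuggingfaceProvider,
    }
    if provider_name not in providers:
        raise Exception('provider {} not found'.format(provider_name))
    return providers[provider_name](tenant_id)
```
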
- """ - return { - 'openai_api_key': self.get_provider_api_key(model_id=model_id) - } - - def get_provider_name(self): - return ProviderName.OPENAI - - def config_validate(self, config: Union[dict | str]): - """ - Validates the given config. - """ - try: - Moderation(self.get_provider_name().value, config).moderate('test') - except (AuthenticationError, OpenAIError) as ex: - raise ValidateFailedError(str(ex)) - except Exception as ex: - logging.exception('OpenAI config validation failed') - raise ex - - def get_hosted_credentials(self) -> Union[str | dict]: - if not hosted_llm_credentials.openai or not hosted_llm_credentials.openai.api_key: - raise ProviderTokenNotInitError( - f"No valid {self.get_provider_name().value} model provider credentials found. " - f"Please go to Settings -> Model Provider to complete your provider credentials." - ) - - return hosted_llm_credentials.openai.api_key diff --git a/api/core/llm/streamable_chat_anthropic.py b/api/core/llm/streamable_chat_anthropic.py deleted file mode 100644 index 9b94227912..0000000000 --- a/api/core/llm/streamable_chat_anthropic.py +++ /dev/null @@ -1,62 +0,0 @@ -from typing import List, Optional, Any, Dict - -from httpx import Timeout -from langchain.callbacks.manager import Callbacks -from langchain.chat_models import ChatAnthropic -from langchain.schema import BaseMessage, LLMResult, SystemMessage, AIMessage, HumanMessage, ChatMessage -from pydantic import root_validator - -from core.llm.wrappers.anthropic_wrapper import handle_anthropic_exceptions - - -class StreamableChatAnthropic(ChatAnthropic): - """ - Wrapper around Anthropic's large language model. - """ - - default_request_timeout: Optional[float] = Timeout(timeout=300.0, connect=5.0) - - @root_validator() - def prepare_params(cls, values: Dict) -> Dict: - values['model_name'] = values.get('model') - values['max_tokens'] = values.get('max_tokens_to_sample') - return values - - @handle_anthropic_exceptions - def generate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - metadata: Optional[Dict[str, Any]] = None, - **kwargs: Any, - ) -> LLMResult: - return super().generate(messages, stop, callbacks, tags=tags, metadata=metadata, **kwargs) - - @classmethod - def get_kwargs_from_model_params(cls, params: dict): - params['model'] = params.get('model_name') - del params['model_name'] - - params['max_tokens_to_sample'] = params.get('max_tokens') - del params['max_tokens'] - - del params['frequency_penalty'] - del params['presence_penalty'] - - return params - - def _convert_one_message_to_text(self, message: BaseMessage) -> str: - if isinstance(message, ChatMessage): - message_text = f"\n\n{message.role.capitalize()}: {message.content}" - elif isinstance(message, HumanMessage): - message_text = f"{self.HUMAN_PROMPT} {message.content}" - elif isinstance(message, AIMessage): - message_text = f"{self.AI_PROMPT} {message.content}" - elif isinstance(message, SystemMessage): - message_text = f"{message.content}" - else: - raise ValueError(f"Got unknown type {message}") - return message_text \ No newline at end of file diff --git a/api/core/llm/token_calculator.py b/api/core/llm/token_calculator.py deleted file mode 100644 index e45f2b4d62..0000000000 --- a/api/core/llm/token_calculator.py +++ /dev/null @@ -1,41 +0,0 @@ -import decimal -from typing import Optional - -import tiktoken - -from core.constant import llm_constant - - -class TokenCalculator: - @classmethod - def 
get_num_tokens(cls, model_name: str, text: str): - if len(text) == 0: - return 0 - - enc = tiktoken.encoding_for_model(model_name) - - tokenized_text = enc.encode(text) - - # calculate the number of tokens in the encoded text - return len(tokenized_text) - - @classmethod - def get_token_price(cls, model_name: str, tokens: int, text_type: Optional[str] = None) -> decimal.Decimal: - if model_name in llm_constant.models_by_mode['embedding']: - unit_price = llm_constant.model_prices[model_name]['usage'] - elif text_type == 'prompt': - unit_price = llm_constant.model_prices[model_name]['prompt'] - elif text_type == 'completion': - unit_price = llm_constant.model_prices[model_name]['completion'] - else: - raise Exception('Invalid text type') - - tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), - rounding=decimal.ROUND_HALF_UP) - - total_price = tokens_per_1k * unit_price - return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) - - @classmethod - def get_currency(cls, model_name: str): - return llm_constant.model_currency diff --git a/api/core/llm/whisper.py b/api/core/llm/whisper.py deleted file mode 100644 index 7f3bf3d794..0000000000 --- a/api/core/llm/whisper.py +++ /dev/null @@ -1,26 +0,0 @@ -import openai - -from core.llm.wrappers.openai_wrapper import handle_openai_exceptions -from models.provider import ProviderName -from core.llm.provider.base import BaseProvider - - -class Whisper: - - def __init__(self, provider: BaseProvider): - self.provider = provider - - if self.provider.get_provider_name() == ProviderName.OPENAI: - self.client = openai.Audio - self.credentials = provider.get_credentials() - - @handle_openai_exceptions - def transcribe(self, file): - return self.client.transcribe( - model='whisper-1', - file=file, - api_key=self.credentials.get('openai_api_key'), - api_base=self.credentials.get('openai_api_base'), - api_type=self.credentials.get('openai_api_type'), - api_version=self.credentials.get('openai_api_version'), - ) diff --git a/api/core/llm/wrappers/anthropic_wrapper.py b/api/core/llm/wrappers/anthropic_wrapper.py deleted file mode 100644 index 7fddc277d2..0000000000 --- a/api/core/llm/wrappers/anthropic_wrapper.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging -from functools import wraps - -import anthropic - -from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ - LLMBadRequestError - - -def handle_anthropic_exceptions(func): - @wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except anthropic.APIConnectionError as e: - logging.exception("Failed to connect to Anthropic API.") - raise LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {e.__cause__}") - except anthropic.RateLimitError: - raise LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") - except anthropic.AuthenticationError as e: - raise LLMAuthorizationError(f"Anthropic: {e.message}") - except anthropic.BadRequestError as e: - raise LLMBadRequestError(f"Anthropic: {e.message}") - except anthropic.APIStatusError as e: - raise LLMAPIUnavailableError(f"Anthropic: code: {e.status_code}, cause: {e.message}") - - return wrapper diff --git a/api/core/llm/wrappers/openai_wrapper.py b/api/core/llm/wrappers/openai_wrapper.py deleted file mode 100644 index 7f96e75edf..0000000000 --- a/api/core/llm/wrappers/openai_wrapper.py +++ /dev/null @@ -1,31 +0,0 @@ -import logging -from functools import 
wraps - -import openai - -from core.llm.error import LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, \ - LLMBadRequestError - - -def handle_openai_exceptions(func): - @wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except openai.error.InvalidRequestError as e: - logging.exception("Invalid request to OpenAI API.") - raise LLMBadRequestError(str(e)) - except openai.error.APIConnectionError as e: - logging.exception("Failed to connect to OpenAI API.") - raise LLMAPIConnectionError(e.__class__.__name__ + ":" + str(e)) - except (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout) as e: - logging.exception("OpenAI service unavailable.") - raise LLMAPIUnavailableError(e.__class__.__name__ + ":" + str(e)) - except openai.error.RateLimitError as e: - raise LLMRateLimitError(str(e)) - except openai.error.AuthenticationError as e: - raise LLMAuthorizationError(str(e)) - except openai.error.OpenAIError as e: - raise LLMBadRequestError(e.__class__.__name__ + ":" + str(e)) - - return wrapper diff --git a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py index d96187ece0..55d70d38ad 100644 --- a/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py +++ b/api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py @@ -1,10 +1,10 @@ -from typing import Any, List, Dict, Union +from typing import Any, List, Dict from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import get_buffer_string, BaseMessage, HumanMessage, AIMessage, BaseLanguageModel +from langchain.schema import get_buffer_string, BaseMessage -from core.llm.streamable_chat_open_ai import StreamableChatOpenAI -from core.llm.streamable_open_ai import StreamableOpenAI +from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages +from core.model_providers.models.llm.base import BaseLLM from extensions.ext_database import db from models.model import Conversation, Message @@ -13,7 +13,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): conversation: Conversation human_prefix: str = "Human" ai_prefix: str = "Assistant" - llm: BaseLanguageModel + model_instance: BaseLLM memory_key: str = "chat_history" max_token_limit: int = 2000 message_limit: int = 10 @@ -29,23 +29,23 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory): messages = list(reversed(messages)) - chat_messages: List[BaseMessage] = [] + chat_messages: List[PromptMessage] = [] for message in messages: - chat_messages.append(HumanMessage(content=message.query)) - chat_messages.append(AIMessage(content=message.answer)) + chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN)) + chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT)) if not chat_messages: - return chat_messages + return [] # prune the chat message if it exceeds the max token limit - curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) + curr_buffer_length = self.model_instance.get_num_tokens(chat_messages) if curr_buffer_length > self.max_token_limit: pruned_memory = [] while curr_buffer_length > self.max_token_limit and chat_messages: pruned_memory.append(chat_messages.pop(0)) - curr_buffer_length = self.llm.get_num_tokens_from_messages(chat_messages) + curr_buffer_length = 
self.model_instance.get_num_tokens(chat_messages) - return chat_messages + return to_lc_messages(chat_messages) @property def memory_variables(self) -> List[str]: diff --git a/api/core/llm/error.py b/api/core/model_providers/error.py similarity index 100% rename from api/core/llm/error.py rename to api/core/model_providers/error.py diff --git a/api/core/model_providers/model_factory.py b/api/core/model_providers/model_factory.py new file mode 100644 index 0000000000..b76a640256 --- /dev/null +++ b/api/core/model_providers/model_factory.py @@ -0,0 +1,293 @@ +from typing import Optional + +from langchain.callbacks.base import Callbacks + +from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError +from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.embedding.base import BaseEmbedding +from core.model_providers.models.entity.model_params import ModelKwargs, ModelType +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.speech2text.base import BaseSpeech2Text +from extensions.ext_database import db +from models.provider import TenantDefaultModel + + +class ModelFactory: + + @classmethod + def get_text_generation_model_from_model_config(cls, tenant_id: str, + model_config: dict, + streaming: bool = False, + callbacks: Callbacks = None) -> Optional[BaseLLM]: + provider_name = model_config.get("provider") + model_name = model_config.get("name") + completion_params = model_config.get("completion_params", {}) + + return cls.get_text_generation_model( + tenant_id=tenant_id, + model_provider_name=provider_name, + model_name=model_name, + model_kwargs=ModelKwargs( + temperature=completion_params.get('temperature', 0), + max_tokens=completion_params.get('max_tokens', 256), + top_p=completion_params.get('top_p', 0), + frequency_penalty=completion_params.get('frequency_penalty', 0.1), + presence_penalty=completion_params.get('presence_penalty', 0.1) + ), + streaming=streaming, + callbacks=callbacks + ) + + @classmethod + def get_text_generation_model(cls, + tenant_id: str, + model_provider_name: Optional[str] = None, + model_name: Optional[str] = None, + model_kwargs: Optional[ModelKwargs] = None, + streaming: bool = False, + callbacks: Callbacks = None) -> Optional[BaseLLM]: + """ + get text generation model. + + :param tenant_id: a string representing the ID of the tenant. + :param model_provider_name: + :param model_name: + :param model_kwargs: + :param streaming: + :param callbacks: + :return: + """ + is_default_model = False + if model_provider_name is None and model_name is None: + default_model = cls.get_default_model(tenant_id, ModelType.TEXT_GENERATION) + + if not default_model: + raise LLMBadRequestError(f"Default model is not available. 
" + f"Please configure a Default System Reasoning Model " + f"in the Settings -> Model Provider.") + + model_provider_name = default_model.provider_name + model_name = default_model.model_name + is_default_model = True + + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + + if not model_provider: + raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") + + # init text generation model + model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION) + + try: + model_instance = model_class( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs, + streaming=streaming, + callbacks=callbacks + ) + except LLMBadRequestError as e: + if is_default_model: + raise LLMBadRequestError(f"Default model {model_name} is not available. " + f"Please check your model provider credentials.") + else: + raise e + + if is_default_model: + model_instance.deduct_quota = False + + return model_instance + + @classmethod + def get_embedding_model(cls, + tenant_id: str, + model_provider_name: Optional[str] = None, + model_name: Optional[str] = None) -> Optional[BaseEmbedding]: + """ + get embedding model. + + :param tenant_id: a string representing the ID of the tenant. + :param model_provider_name: + :param model_name: + :return: + """ + if model_provider_name is None and model_name is None: + default_model = cls.get_default_model(tenant_id, ModelType.EMBEDDINGS) + + if not default_model: + raise LLMBadRequestError(f"Default model is not available. " + f"Please configure a Default Embedding Model " + f"in the Settings -> Model Provider.") + + model_provider_name = default_model.provider_name + model_name = default_model.model_name + + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + + if not model_provider: + raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") + + # init embedding model + model_class = model_provider.get_model_class(model_type=ModelType.EMBEDDINGS) + return model_class( + model_provider=model_provider, + name=model_name + ) + + @classmethod + def get_speech2text_model(cls, + tenant_id: str, + model_provider_name: Optional[str] = None, + model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]: + """ + get speech to text model. + + :param tenant_id: a string representing the ID of the tenant. + :param model_provider_name: + :param model_name: + :return: + """ + if model_provider_name is None and model_name is None: + default_model = cls.get_default_model(tenant_id, ModelType.SPEECH_TO_TEXT) + + if not default_model: + raise LLMBadRequestError(f"Default model is not available. 
" + f"Please configure a Default Speech-to-Text Model " + f"in the Settings -> Model Provider.") + + model_provider_name = default_model.provider_name + model_name = default_model.model_name + + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + + if not model_provider: + raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") + + # init speech to text model + model_class = model_provider.get_model_class(model_type=ModelType.SPEECH_TO_TEXT) + return model_class( + model_provider=model_provider, + name=model_name + ) + + @classmethod + def get_moderation_model(cls, + tenant_id: str, + model_provider_name: str, + model_name: str) -> Optional[BaseProviderModel]: + """ + get moderation model. + + :param tenant_id: a string representing the ID of the tenant. + :param model_provider_name: + :param model_name: + :return: + """ + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + + if not model_provider: + raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") + + # init moderation model + model_class = model_provider.get_model_class(model_type=ModelType.MODERATION) + return model_class( + model_provider=model_provider, + name=model_name + ) + + @classmethod + def get_default_model(cls, tenant_id: str, model_type: ModelType) -> TenantDefaultModel: + """ + get default model of model type. + + :param tenant_id: + :param model_type: + :return: + """ + # get default model + default_model = db.session.query(TenantDefaultModel) \ + .filter( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.value + ).first() + + if not default_model: + model_provider_rules = ModelProviderFactory.get_provider_rules() + for model_provider_name, model_provider_rule in model_provider_rules.items(): + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + if not model_provider: + continue + + model_list = model_provider.get_supported_model_list(model_type) + if model_list: + model_info = model_list[0] + default_model = TenantDefaultModel( + tenant_id=tenant_id, + model_type=model_type.value, + provider_name=model_provider_name, + model_name=model_info['id'] + ) + db.session.add(default_model) + db.session.commit() + break + + return default_model + + @classmethod + def update_default_model(cls, + tenant_id: str, + model_type: ModelType, + provider_name: str, + model_name: str) -> TenantDefaultModel: + """ + update default model of model type. 
+ + :param tenant_id: + :param model_type: + :param provider_name: + :param model_name: + :return: + """ + model_provider_names = ModelProviderFactory.get_provider_names() + if provider_name not in model_provider_names: + raise ValueError(f'Invalid provider name: {provider_name}') + + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, provider_name) + + if not model_provider: + raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.") + + model_list = model_provider.get_supported_model_list(model_type) + model_ids = [model['id'] for model in model_list] + if model_name not in model_ids: + raise ValueError(f'Invalid model name: {model_name}') + + # get default model + default_model = db.session.query(TenantDefaultModel) \ + .filter( + TenantDefaultModel.tenant_id == tenant_id, + TenantDefaultModel.model_type == model_type.value + ).first() + + if default_model: + # update default model + default_model.provider_name = provider_name + default_model.model_name = model_name + db.session.commit() + else: + # create default model + default_model = TenantDefaultModel( + tenant_id=tenant_id, + model_type=model_type.value, + provider_name=provider_name, + model_name=model_name, + ) + db.session.add(default_model) + db.session.commit() + + return default_model diff --git a/api/core/model_providers/model_provider_factory.py b/api/core/model_providers/model_provider_factory.py new file mode 100644 index 0000000000..e2d8b43603 --- /dev/null +++ b/api/core/model_providers/model_provider_factory.py @@ -0,0 +1,228 @@ +from typing import Type + +from sqlalchemy.exc import IntegrityError + +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import BaseModelProvider +from core.model_providers.rules import provider_rules +from extensions.ext_database import db +from models.provider import TenantPreferredModelProvider, ProviderType, Provider, ProviderQuotaType + +DEFAULT_MODELS = { + ModelType.TEXT_GENERATION.value: { + 'provider_name': 'openai', + 'model_name': 'gpt-3.5-turbo', + }, + ModelType.EMBEDDINGS.value: { + 'provider_name': 'openai', + 'model_name': 'text-embedding-ada-002', + }, + ModelType.SPEECH_TO_TEXT.value: { + 'provider_name': 'openai', + 'model_name': 'whisper-1', + } +} + + +class ModelProviderFactory: + @classmethod + def get_model_provider_class(cls, provider_name: str) -> Type[BaseModelProvider]: + if provider_name == 'openai': + from core.model_providers.providers.openai_provider import OpenAIProvider + return OpenAIProvider + elif provider_name == 'anthropic': + from core.model_providers.providers.anthropic_provider import AnthropicProvider + return AnthropicProvider + elif provider_name == 'minimax': + from core.model_providers.providers.minimax_provider import MinimaxProvider + return MinimaxProvider + elif provider_name == 'spark': + from core.model_providers.providers.spark_provider import SparkProvider + return SparkProvider + elif provider_name == 'tongyi': + from core.model_providers.providers.tongyi_provider import TongyiProvider + return TongyiProvider + elif provider_name == 'wenxin': + from core.model_providers.providers.wenxin_provider import WenxinProvider + return WenxinProvider + elif provider_name == 'chatglm': + from core.model_providers.providers.chatglm_provider import ChatGLMProvider + return ChatGLMProvider + elif provider_name == 'azure_openai': + from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider + return 
AzureOpenAIProvider + elif provider_name == 'replicate': + from core.model_providers.providers.replicate_provider import ReplicateProvider + return ReplicateProvider + elif provider_name == 'huggingface_hub': + from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider + return HuggingfaceHubProvider + else: + raise NotImplementedError + + @classmethod + def get_provider_names(cls): + """ + Returns a list of provider names. + """ + return list(provider_rules.keys()) + + @classmethod + def get_provider_rules(cls): + """ + Returns a list of provider rules. + + :return: + """ + return provider_rules + + @classmethod + def get_provider_rule(cls, provider_name: str): + """ + Returns provider rule. + """ + return provider_rules[provider_name] + + @classmethod + def get_preferred_model_provider(cls, tenant_id: str, model_provider_name: str): + """ + get preferred model provider. + + :param tenant_id: a string representing the ID of the tenant. + :param model_provider_name: + :return: + """ + # get preferred provider + preferred_provider = cls._get_preferred_provider(tenant_id, model_provider_name) + if not preferred_provider or not preferred_provider.is_valid: + return None + + # init model provider + model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name) + return model_provider_class(provider=preferred_provider) + + @classmethod + def get_preferred_type_by_preferred_model_provider(cls, + tenant_id: str, + model_provider_name: str, + preferred_model_provider: TenantPreferredModelProvider): + """ + get preferred provider type by preferred model provider. + + :param model_provider_name: + :param preferred_model_provider: + :return: + """ + if not preferred_model_provider: + model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) + support_provider_types = model_provider_rules['support_provider_types'] + + if ProviderType.CUSTOM.value in support_provider_types: + custom_provider = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_provider_name, + Provider.provider_type == ProviderType.CUSTOM.value, + Provider.is_valid == True + ).first() + + if custom_provider: + return ProviderType.CUSTOM.value + + model_provider = cls.get_model_provider_class(model_provider_name) + + if ProviderType.SYSTEM.value in support_provider_types \ + and model_provider.is_provider_type_system_supported(): + return ProviderType.SYSTEM.value + elif ProviderType.CUSTOM.value in support_provider_types: + return ProviderType.CUSTOM.value + else: + return preferred_model_provider.preferred_provider_type + + @classmethod + def _get_preferred_provider(cls, tenant_id: str, model_provider_name: str): + """ + get preferred provider of tenant. 
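
A minimal usage sketch of the two factories above, assuming this PR's wiring; the tenant id, provider choice, and parameter values are hypothetical:

    from core.model_providers.model_factory import ModelFactory
    from core.model_providers.model_provider_factory import ModelProviderFactory
    from core.model_providers.models.entity.model_params import ModelKwargs

    # Resolve the tenant's effective provider (custom credentials or system quota).
    provider = ModelProviderFactory.get_preferred_model_provider('tenant-123', 'openai')
    if provider is None:
        raise RuntimeError('No valid openai credentials or quota for this tenant')

    # Or let ModelFactory do the resolution and build a ready-to-use model instance.
    model = ModelFactory.get_text_generation_model(
        tenant_id='tenant-123',
        model_provider_name='openai',
        model_name='gpt-3.5-turbo',
        model_kwargs=ModelKwargs(temperature=0.7, max_tokens=256, top_p=1.0,
                                 presence_penalty=0.0, frequency_penalty=0.0)
    )
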
+ + :param tenant_id: + :param model_provider_name: + :return: + """ + # get preferred provider type + preferred_provider_type = cls._get_preferred_provider_type(tenant_id, model_provider_name) + + # get providers by preferred provider type + providers = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_provider_name, + Provider.provider_type == preferred_provider_type + ).all() + + no_system_provider = False + if preferred_provider_type == ProviderType.SYSTEM.value: + quota_type_to_provider_dict = {} + for provider in providers: + quota_type_to_provider_dict[provider.quota_type] = provider + + model_provider_rules = ModelProviderFactory.get_provider_rule(model_provider_name) + for quota_type_enum in ProviderQuotaType: + quota_type = quota_type_enum.value + if quota_type in model_provider_rules['system_config']['supported_quota_types'] \ + and quota_type in quota_type_to_provider_dict.keys(): + provider = quota_type_to_provider_dict[quota_type] + if provider.is_valid and provider.quota_limit > provider.quota_used: + return provider + + no_system_provider = True + + if no_system_provider: + providers = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_provider_name, + Provider.provider_type == ProviderType.CUSTOM.value + ).all() + + if preferred_provider_type == ProviderType.CUSTOM.value or no_system_provider: + if providers: + return providers[0] + else: + try: + provider = Provider( + tenant_id=tenant_id, + provider_name=model_provider_name, + provider_type=ProviderType.CUSTOM.value, + is_valid=False + ) + db.session.add(provider) + db.session.commit() + except IntegrityError: + db.session.rollback() + provider = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == model_provider_name, + Provider.provider_type == ProviderType.CUSTOM.value + ).first() + + return provider + + return None + + @classmethod + def _get_preferred_provider_type(cls, tenant_id: str, model_provider_name: str): + """ + get preferred provider type of tenant. 
+ + :param tenant_id: + :param model_provider_name: + :return: + """ + preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + .filter( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name == model_provider_name + ).first() + + return cls.get_preferred_type_by_preferred_model_provider(tenant_id, model_provider_name, preferred_model_provider) diff --git a/api/tests/test_libs/__init__.py b/api/core/model_providers/models/__init__.py similarity index 100% rename from api/tests/test_libs/__init__.py rename to api/core/model_providers/models/__init__.py diff --git a/api/core/model_providers/models/base.py b/api/core/model_providers/models/base.py new file mode 100644 index 0000000000..01f83efa84 --- /dev/null +++ b/api/core/model_providers/models/base.py @@ -0,0 +1,22 @@ +from abc import ABC +from typing import Any + +from core.model_providers.providers.base import BaseModelProvider + + +class BaseProviderModel(ABC): + _client: Any + _model_provider: BaseModelProvider + + def __init__(self, model_provider: BaseModelProvider, client: Any): + self._model_provider = model_provider + self._client = client + + @property + def client(self): + return self._client + + @property + def model_provider(self): + return self._model_provider + diff --git a/api/tests/test_models/__init__.py b/api/core/model_providers/models/embedding/__init__.py similarity index 100% rename from api/tests/test_models/__init__.py rename to api/core/model_providers/models/embedding/__init__.py diff --git a/api/core/model_providers/models/embedding/azure_openai_embedding.py b/api/core/model_providers/models/embedding/azure_openai_embedding.py new file mode 100644 index 0000000000..81f08784b2 --- /dev/null +++ b/api/core/model_providers/models/embedding/azure_openai_embedding.py @@ -0,0 +1,78 @@ +import decimal +import logging + +import openai +import tiktoken +from langchain.embeddings import OpenAIEmbeddings + +from core.model_providers.error import LLMBadRequestError, LLMAuthorizationError, LLMRateLimitError, \ + LLMAPIUnavailableError, LLMAPIConnectionError +from core.model_providers.models.embedding.base import BaseEmbedding +from core.model_providers.providers.base import BaseModelProvider + +AZURE_OPENAI_API_VERSION = '2023-07-01-preview' + + +class AzureOpenAIEmbedding(BaseEmbedding): + def __init__(self, model_provider: BaseModelProvider, name: str): + self.credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = OpenAIEmbeddings( + deployment=name, + openai_api_type='azure', + openai_api_version=AZURE_OPENAI_API_VERSION, + chunk_size=16, + max_retries=1, + **self.credentials + ) + + super().__init__(model_provider, client, name) + + def get_num_tokens(self, text: str) -> int: + """ + get num tokens of text. 
+ + :param text: + :return: + """ + if len(text) == 0: + return 0 + + enc = tiktoken.encoding_for_model(self.credentials.get('base_model_name')) + + tokenized_text = enc.encode(text) + + # calculate the number of tokens in the encoded text + return len(tokenized_text) + + def get_token_price(self, tokens: int): + tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1k * decimal.Decimal('0.0001') + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + def get_currency(self): + return 'USD' + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to Azure OpenAI API.") + return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to Azure OpenAI API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("Azure OpenAI service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return LLMRateLimitError('Azure ' + str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + raise LLMAuthorizationError('Azure ' + str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex)) + else: + return ex diff --git a/api/core/model_providers/models/embedding/base.py b/api/core/model_providers/models/embedding/base.py new file mode 100644 index 0000000000..fc42d88bcd --- /dev/null +++ b/api/core/model_providers/models/embedding/base.py @@ -0,0 +1,40 @@ +from abc import abstractmethod +from typing import Any + +import tiktoken +from langchain.schema.language_model import _get_token_ids_default_method + +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import BaseModelProvider + + +class BaseEmbedding(BaseProviderModel): + name: str + type: ModelType = ModelType.EMBEDDINGS + + def __init__(self, model_provider: BaseModelProvider, client: Any, name: str): + super().__init__(model_provider, client) + self.name = name + + def get_num_tokens(self, text: str) -> int: + """ + get num tokens of text. 
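
The Azure implementation above counts tokens with tiktoken rather than the base class's default tokenizer; a standalone sketch, using one of the known base model names:

    import tiktoken

    # Resolve the encoding for a known base model, then count tokens in a string.
    enc = tiktoken.encoding_for_model('text-embedding-ada-002')
    num_tokens = len(enc.encode('Dify model providers refactor'))
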
+ + :param text: + :return: + """ + if len(text) == 0: + return 0 + + return len(_get_token_ids_default_method(text)) + + def get_token_price(self, tokens: int): + return 0 + + def get_currency(self): + return 'USD' + + @abstractmethod + def handle_exceptions(self, ex: Exception) -> Exception: + raise NotImplementedError diff --git a/api/core/model_providers/models/embedding/minimax_embedding.py b/api/core/model_providers/models/embedding/minimax_embedding.py new file mode 100644 index 0000000000..d8cb22f347 --- /dev/null +++ b/api/core/model_providers/models/embedding/minimax_embedding.py @@ -0,0 +1,35 @@ +import decimal +import logging + +from langchain.embeddings import MiniMaxEmbeddings + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.embedding.base import BaseEmbedding +from core.model_providers.providers.base import BaseModelProvider + + +class MinimaxEmbedding(BaseEmbedding): + def __init__(self, model_provider: BaseModelProvider, name: str): + credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = MiniMaxEmbeddings( + model=name, + **credentials + ) + + super().__init__(model_provider, client, name) + + def get_token_price(self, tokens: int): + return decimal.Decimal('0') + + def get_currency(self): + return 'RMB' + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, ValueError): + return LLMBadRequestError(f"Minimax: {str(ex)}") + else: + return ex diff --git a/api/core/model_providers/models/embedding/openai_embedding.py b/api/core/model_providers/models/embedding/openai_embedding.py new file mode 100644 index 0000000000..1d7af94fdb --- /dev/null +++ b/api/core/model_providers/models/embedding/openai_embedding.py @@ -0,0 +1,72 @@ +import decimal +import logging + +import openai +import tiktoken +from langchain.embeddings import OpenAIEmbeddings + +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, LLMAuthorizationError +from core.model_providers.models.embedding.base import BaseEmbedding +from core.model_providers.providers.base import BaseModelProvider + + +class OpenAIEmbedding(BaseEmbedding): + def __init__(self, model_provider: BaseModelProvider, name: str): + credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = OpenAIEmbeddings( + max_retries=1, + **credentials + ) + + super().__init__(model_provider, client, name) + + def get_num_tokens(self, text: str) -> int: + """ + get num tokens of text. 
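
The per-1K pricing arithmetic used for the Azure class above (and repeated for the OpenAI class below) works out as in this worked example; the token count is arbitrary:

    import decimal

    tokens = 1234
    tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(
        decimal.Decimal('0.001'), rounding=decimal.ROUND_HALF_UP)  # 1.234
    total_price = (tokens_per_1k * decimal.Decimal('0.0001')).quantize(
        decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)  # 0.0001234 USD
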
+ + :param text: + :return: + """ + if len(text) == 0: + return 0 + + enc = tiktoken.encoding_for_model(self.name) + + tokenized_text = enc.encode(text) + + # calculate the number of tokens in the encoded text + return len(tokenized_text) + + def get_token_price(self, tokens: int): + tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1k * decimal.Decimal('0.0001') + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + def get_currency(self): + return 'USD' + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to OpenAI API.") + return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to OpenAI API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("OpenAI service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return LLMRateLimitError(str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + raise LLMAuthorizationError(str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) + else: + return ex diff --git a/api/core/model_providers/models/embedding/replicate_embedding.py b/api/core/model_providers/models/embedding/replicate_embedding.py new file mode 100644 index 0000000000..3f7ef2851d --- /dev/null +++ b/api/core/model_providers/models/embedding/replicate_embedding.py @@ -0,0 +1,36 @@ +import decimal + +from replicate.exceptions import ModelError, ReplicateError + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.providers.base import BaseModelProvider +from core.third_party.langchain.embeddings.replicate_embedding import ReplicateEmbeddings +from core.model_providers.models.embedding.base import BaseEmbedding + + +class ReplicateEmbedding(BaseEmbedding): + def __init__(self, model_provider: BaseModelProvider, name: str): + credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + + client = ReplicateEmbeddings( + model=name + ':' + credentials.get('model_version'), + replicate_api_token=credentials.get('replicate_api_token') + ) + + super().__init__(model_provider, client, name) + + def get_token_price(self, tokens: int): + # replicate only pay for prediction seconds + return decimal.Decimal('0') + + def get_currency(self): + return 'USD' + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, (ModelError, ReplicateError)): + return LLMBadRequestError(f"Replicate: {str(ex)}") + else: + return ex diff --git a/api/tests/test_services/__init__.py b/api/core/model_providers/models/entity/__init__.py similarity index 100% rename from api/tests/test_services/__init__.py rename to api/core/model_providers/models/entity/__init__.py diff --git a/api/core/model_providers/models/entity/message.py b/api/core/model_providers/models/entity/message.py new file mode 100644 index 0000000000..f2fab9c4b7 --- /dev/null +++ b/api/core/model_providers/models/entity/message.py @@ -0,0 +1,53 @@ +import enum + +from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage +from 
pydantic import BaseModel + + +class LLMRunResult(BaseModel): + content: str + prompt_tokens: int + completion_tokens: int + + +class MessageType(enum.Enum): + HUMAN = 'human' + ASSISTANT = 'assistant' + SYSTEM = 'system' + + +class PromptMessage(BaseModel): + type: MessageType = MessageType.HUMAN + content: str = '' + + +def to_lc_messages(messages: list[PromptMessage]): + lc_messages = [] + for message in messages: + if message.type == MessageType.HUMAN: + lc_messages.append(HumanMessage(content=message.content)) + elif message.type == MessageType.ASSISTANT: + lc_messages.append(AIMessage(content=message.content)) + elif message.type == MessageType.SYSTEM: + lc_messages.append(SystemMessage(content=message.content)) + + return lc_messages + + +def to_prompt_messages(messages: list[BaseMessage]): + prompt_messages = [] + for message in messages: + if isinstance(message, HumanMessage): + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN)) + elif isinstance(message, AIMessage): + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT)) + elif isinstance(message, SystemMessage): + prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM)) + return prompt_messages + + +def str_to_prompt_messages(texts: list[str]): + prompt_messages = [] + for text in texts: + prompt_messages.append(PromptMessage(content=text)) + return prompt_messages diff --git a/api/core/model_providers/models/entity/model_params.py b/api/core/model_providers/models/entity/model_params.py new file mode 100644 index 0000000000..2a6a1bc510 --- /dev/null +++ b/api/core/model_providers/models/entity/model_params.py @@ -0,0 +1,59 @@ +import enum +from typing import Optional, TypeVar, Generic + +from langchain.load.serializable import Serializable +from pydantic import BaseModel + + +class ModelMode(enum.Enum): + COMPLETION = 'completion' + CHAT = 'chat' + + +class ModelType(enum.Enum): + TEXT_GENERATION = 'text-generation' + EMBEDDINGS = 'embeddings' + SPEECH_TO_TEXT = 'speech2text' + IMAGE = 'image' + VIDEO = 'video' + MODERATION = 'moderation' + + @staticmethod + def value_of(value): + for member in ModelType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + +class ModelKwargs(BaseModel): + max_tokens: Optional[int] + temperature: Optional[float] + top_p: Optional[float] + presence_penalty: Optional[float] + frequency_penalty: Optional[float] + + +class KwargRuleType(enum.Enum): + STRING = 'string' + INTEGER = 'integer' + FLOAT = 'float' + + +T = TypeVar('T') + + +class KwargRule(Generic[T], BaseModel): + enabled: bool = True + min: Optional[T] = None + max: Optional[T] = None + default: Optional[T] = None + alias: Optional[str] = None + + +class ModelKwargsRules(BaseModel): + max_tokens: KwargRule = KwargRule[int](enabled=False) + temperature: KwargRule = KwargRule[float](enabled=False) + top_p: KwargRule = KwargRule[float](enabled=False) + presence_penalty: KwargRule = KwargRule[float](enabled=False) + frequency_penalty: KwargRule = KwargRule[float](enabled=False) diff --git a/api/core/model_providers/models/entity/provider.py b/api/core/model_providers/models/entity/provider.py new file mode 100644 index 0000000000..07249eb37b --- /dev/null +++ b/api/core/model_providers/models/entity/provider.py @@ -0,0 +1,10 @@ +from enum import Enum + + +class ProviderQuotaUnit(Enum): + TIMES = 'times' + TOKENS = 'tokens' + + +class ModelFeature(Enum): + AGENT_THOUGHT = 
'agent_thought' diff --git a/api/core/model_providers/models/llm/__init__.py b/api/core/model_providers/models/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_providers/models/llm/anthropic_model.py b/api/core/model_providers/models/llm/anthropic_model.py new file mode 100644 index 0000000000..69dd76611f --- /dev/null +++ b/api/core/model_providers/models/llm/anthropic_model.py @@ -0,0 +1,107 @@ +import decimal +import logging +from functools import wraps +from typing import List, Optional, Any + +import anthropic +from langchain.callbacks.manager import Callbacks +from langchain.chat_models import ChatAnthropic +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, LLMAuthorizationError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + + +class AnthropicModel(BaseLLM): + model_mode: ModelMode = ModelMode.CHAT + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + return ChatAnthropic( + model=self.name, + streaming=self.streaming, + callbacks=self.callbacks, + default_request_timeout=60, + **self.credentials, + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
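
A short sketch of the PromptMessage container from entity/message.py that these model classes consume; the message contents are made up:

    from core.model_providers.models.entity.message import (
        PromptMessage, MessageType, to_lc_messages
    )

    history = [
        PromptMessage(content='You are a helpful assistant.', type=MessageType.SYSTEM),
        PromptMessage(content='Hello!', type=MessageType.HUMAN),
    ]
    # Convert to langchain messages: [SystemMessage(...), HumanMessage(...)]
    lc_messages = to_lc_messages(history)
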
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + model_unit_prices = { + 'claude-instant-1': { + 'prompt': decimal.Decimal('1.63'), + 'completion': decimal.Decimal('5.51'), + }, + 'claude-2': { + 'prompt': decimal.Decimal('11.02'), + 'completion': decimal.Decimal('32.68'), + }, + } + + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + unit_price = model_unit_prices[self.name]['prompt'] + else: + unit_price = model_unit_prices[self.name]['completion'] + + tokens_per_1m = (decimal.Decimal(tokens) / 1000000).quantize(decimal.Decimal('0.000001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1m * unit_price + return total_price.quantize(decimal.Decimal('0.00000001'), rounding=decimal.ROUND_HALF_UP) + + def get_currency(self): + return 'USD' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, anthropic.APIConnectionError): + logging.warning("Failed to connect to Anthropic API.") + return LLMAPIConnectionError(f"Anthropic: The server could not be reached, cause: {ex.__cause__}") + elif isinstance(ex, anthropic.RateLimitError): + return LLMRateLimitError("Anthropic: A 429 status code was received; we should back off a bit.") + elif isinstance(ex, anthropic.AuthenticationError): + return LLMAuthorizationError(f"Anthropic: {ex.message}") + elif isinstance(ex, anthropic.BadRequestError): + return LLMBadRequestError(f"Anthropic: {ex.message}") + elif isinstance(ex, anthropic.APIStatusError): + return LLMAPIUnavailableError(f"Anthropic: code: {ex.status_code}, cause: {ex.message}") + else: + return ex + + @classmethod + def support_streaming(cls): + return True + diff --git a/api/core/model_providers/models/llm/azure_openai_model.py b/api/core/model_providers/models/llm/azure_openai_model.py new file mode 100644 index 0000000000..b2f6159b4f --- /dev/null +++ b/api/core/model_providers/models/llm/azure_openai_model.py @@ -0,0 +1,177 @@ +import decimal +import logging +from functools import wraps +from typing import List, Optional, Any + +import openai +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult + +from core.model_providers.providers.base import BaseModelProvider +from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI +from core.third_party.langchain.llms.azure_open_ai import EnhanceAzureOpenAI +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, LLMAuthorizationError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + +AZURE_OPENAI_API_VERSION = '2023-07-01-preview' + + +class AzureOpenAIModel(BaseLLM): + def __init__(self, model_provider: BaseModelProvider, + name: str, + model_kwargs: ModelKwargs, + streaming: bool = False, + callbacks: Callbacks = None): + if name == 'text-davinci-003': + self.model_mode = ModelMode.COMPLETION + else: + self.model_mode = 
ModelMode.CHAT + + super().__init__(model_provider, name, model_kwargs, streaming, callbacks) + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + if self.name == 'text-davinci-003': + client = EnhanceAzureOpenAI( + deployment_name=self.name, + streaming=self.streaming, + request_timeout=60, + openai_api_type='azure', + openai_api_version=AZURE_OPENAI_API_VERSION, + openai_api_key=self.credentials.get('openai_api_key'), + openai_api_base=self.credentials.get('openai_api_base'), + callbacks=self.callbacks, + **provider_model_kwargs + ) + else: + extra_model_kwargs = { + 'top_p': provider_model_kwargs.get('top_p'), + 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), + 'presence_penalty': provider_model_kwargs.get('presence_penalty'), + } + + client = EnhanceAzureChatOpenAI( + deployment_name=self.name, + temperature=provider_model_kwargs.get('temperature'), + max_tokens=provider_model_kwargs.get('max_tokens'), + model_kwargs=extra_model_kwargs, + streaming=self.streaming, + request_timeout=60, + openai_api_type='azure', + openai_api_version=AZURE_OPENAI_API_VERSION, + openai_api_key=self.credentials.get('openai_api_key'), + openai_api_base=self.credentials.get('openai_api_base'), + callbacks=self.callbacks, + ) + + return client + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + if isinstance(prompts, str): + return self._client.get_num_tokens(prompts) + else: + return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + model_unit_prices = { + 'gpt-4': { + 'prompt': decimal.Decimal('0.03'), + 'completion': decimal.Decimal('0.06'), + }, + 'gpt-4-32k': { + 'prompt': decimal.Decimal('0.06'), + 'completion': decimal.Decimal('0.12') + }, + 'gpt-35-turbo': { + 'prompt': decimal.Decimal('0.0015'), + 'completion': decimal.Decimal('0.002') + }, + 'gpt-35-turbo-16k': { + 'prompt': decimal.Decimal('0.003'), + 'completion': decimal.Decimal('0.004') + }, + 'text-davinci-003': { + 'prompt': decimal.Decimal('0.02'), + 'completion': decimal.Decimal('0.02') + }, + } + + base_model_name = self.credentials.get("base_model_name") + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + unit_price = model_unit_prices[base_model_name]['prompt'] + else: + unit_price = model_unit_prices[base_model_name]['completion'] + + tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1k * unit_price + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + def get_currency(self): + return 'USD' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + if self.name == 'text-davinci-003': + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + else: + extra_model_kwargs = { + 'top_p': provider_model_kwargs.get('top_p'), + 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), + 'presence_penalty': provider_model_kwargs.get('presence_penalty'), + } + + self.client.temperature = provider_model_kwargs.get('temperature') + self.client.max_tokens = provider_model_kwargs.get('max_tokens') + self.client.model_kwargs = extra_model_kwargs + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to Azure OpenAI API.") + return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to Azure OpenAI API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("Azure OpenAI service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return LLMRateLimitError('Azure ' + str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + raise LLMAuthorizationError('Azure ' + str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError('Azure ' + ex.__class__.__name__ + ":" + str(ex)) + else: + return ex + + @classmethod + def support_streaming(cls): + return True \ No newline at end of file diff --git a/api/core/model_providers/models/llm/base.py b/api/core/model_providers/models/llm/base.py new file mode 100644 index 0000000000..31573dd580 --- /dev/null +++ b/api/core/model_providers/models/llm/base.py @@ -0,0 +1,269 @@ +from abc import abstractmethod +from typing import List, Optional, Any, Union + +from 
langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration + +from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult +from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules +from core.model_providers.providers.base import BaseModelProvider +from core.third_party.langchain.llms.fake import FakeLLM + + +class BaseLLM(BaseProviderModel): + model_mode: ModelMode = ModelMode.COMPLETION + name: str + model_kwargs: ModelKwargs + credentials: dict + streaming: bool = False + type: ModelType = ModelType.TEXT_GENERATION + deduct_quota: bool = True + + def __init__(self, model_provider: BaseModelProvider, + name: str, + model_kwargs: ModelKwargs, + streaming: bool = False, + callbacks: Callbacks = None): + self.name = name + self.model_rules = model_provider.get_model_parameter_rules(name, self.type) + self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs( + max_tokens=None, + temperature=None, + top_p=None, + presence_penalty=None, + frequency_penalty=None + ) + self.credentials = model_provider.get_model_credentials( + model_name=name, + model_type=self.type + ) + self.streaming = streaming + + if streaming: + default_callback = DifyStreamingStdOutCallbackHandler() + else: + default_callback = DifyStdOutCallbackHandler() + + if not callbacks: + callbacks = [default_callback] + else: + callbacks.append(default_callback) + + self.callbacks = callbacks + + client = self._init_client() + super().__init__(model_provider, client) + + @abstractmethod + def _init_client(self) -> Any: + raise NotImplementedError + + def run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMRunResult: + """ + run predict by prompt messages and stop words. 
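
A rough usage sketch for the run() method defined here, assuming a model instance obtained from ModelFactory as shown earlier; the prompt and stop words are made up:

    from core.model_providers.models.entity.message import PromptMessage, MessageType

    result = model.run(
        [PromptMessage(content='Summarize Dify in one sentence.', type=MessageType.HUMAN)],
        stop=['\n\n']
    )
    # LLMRunResult carries the completion text plus token accounting.
    print(result.content, result.prompt_tokens, result.completion_tokens)
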
+ + :param messages: + :param stop: + :param callbacks: + :return: + """ + if self.deduct_quota: + self.model_provider.check_quota_over_limit() + + if not callbacks: + callbacks = self.callbacks + else: + callbacks.extend(self.callbacks) + + if 'fake_response' in kwargs and kwargs['fake_response']: + prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) + fake_llm = FakeLLM( + response=kwargs['fake_response'], + num_token_func=self.get_num_tokens, + streaming=self.streaming, + callbacks=callbacks + ) + result = fake_llm.generate([prompts]) + else: + try: + result = self._run( + messages=messages, + stop=stop, + callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None, + **kwargs + ) + except Exception as ex: + raise self.handle_exceptions(ex) + + if isinstance(result.generations[0][0], ChatGeneration): + completion_content = result.generations[0][0].message.content + else: + completion_content = result.generations[0][0].text + + if self.streaming and not self.support_streaming(): + # use FakeLLM to simulate streaming when the current model does not support streaming but streaming is requested + prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT) + fake_llm = FakeLLM( + response=completion_content, + num_token_func=self.get_num_tokens, + streaming=self.streaming, + callbacks=callbacks + ) + fake_llm.generate([prompts]) + + if result.llm_output and result.llm_output.get('token_usage'): + prompt_tokens = result.llm_output['token_usage']['prompt_tokens'] + completion_tokens = result.llm_output['token_usage']['completion_tokens'] + total_tokens = result.llm_output['token_usage']['total_tokens'] + else: + prompt_tokens = self.get_num_tokens(messages) + completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]) + total_tokens = prompt_tokens + completion_tokens + + if self.deduct_quota: + self.model_provider.deduct_quota(total_tokens) + + return LLMRunResult( + content=completion_content, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens + ) + + @abstractmethod + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. + + :param messages: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_token_price(self, tokens: int, message_type: MessageType): + """ + get token price. + + :param tokens: + :param message_type: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_currency(self): + """ + get token currency. + + :return: + """ + raise NotImplementedError + + def get_model_kwargs(self): + return self.model_kwargs + + def set_model_kwargs(self, model_kwargs: ModelKwargs): + self.model_kwargs = model_kwargs + self._set_model_kwargs(model_kwargs) + + @abstractmethod + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + raise NotImplementedError + + @abstractmethod + def handle_exceptions(self, ex: Exception) -> Exception: + """ + Handle llm run exceptions. + + :param ex: + :return: + """ + raise NotImplementedError + + def add_callbacks(self, callbacks: Callbacks): + """ + Add callbacks to client. 
+ + :param callbacks: + :return: + """ + if not self.client.callbacks: + self.client.callbacks = callbacks + else: + self.client.callbacks.extend(callbacks) + + @classmethod + def support_streaming(cls): + return False + + def _get_prompt_from_messages(self, messages: List[PromptMessage], + model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]: + if len(messages) == 0: + raise ValueError("prompt must not be empty.") + + if not model_mode: + model_mode = self.model_mode + + if model_mode == ModelMode.COMPLETION: + return messages[0].content + else: + chat_messages = [] + for message in messages: + if message.type == MessageType.HUMAN: + chat_messages.append(HumanMessage(content=message.content)) + elif message.type == MessageType.ASSISTANT: + chat_messages.append(AIMessage(content=message.content)) + elif message.type == MessageType.SYSTEM: + chat_messages.append(SystemMessage(content=message.content)) + + return chat_messages + + def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict: + """ + convert model kwargs to provider model kwargs. + + :param model_rules: + :param model_kwargs: + :return: + """ + model_kwargs_input = {} + for key, value in model_kwargs.dict().items(): + rule = getattr(model_rules, key) + if not rule.enabled: + continue + + if rule.alias: + key = rule.alias + + if rule.default is not None and value is None: + value = rule.default + + if rule.min is not None: + value = max(value, rule.min) + + if rule.max is not None: + value = min(value, rule.max) + + model_kwargs_input[key] = value + + return model_kwargs_input diff --git a/api/core/model_providers/models/llm/chatglm_model.py b/api/core/model_providers/models/llm/chatglm_model.py new file mode 100644 index 0000000000..42036dbfdd --- /dev/null +++ b/api/core/model_providers/models/llm/chatglm_model.py @@ -0,0 +1,70 @@ +import decimal +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.llms import ChatGLM +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + + +class ChatGLMModel(BaseLLM): + model_mode: ModelMode = ModelMode.COMPLETION + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + return ChatGLM( + callbacks=self.callbacks, + endpoint_url=self.credentials.get('api_base'), + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
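
To illustrate _to_model_kwargs_input above: disabled rules are dropped, None values fall back to the rule default, values are clamped to [min, max], and aliases rename keys. A sketch with hypothetical rules:

    from core.model_providers.models.entity.model_params import (
        KwargRule, ModelKwargs, ModelKwargsRules
    )

    rules = ModelKwargsRules(
        temperature=KwargRule[float](enabled=True, min=0.0, max=1.0, default=0.7),
        max_tokens=KwargRule[int](enabled=True, min=1, max=2048, default=256,
                                  alias='max_new_tokens'),
    )
    kwargs = ModelKwargs(temperature=1.5, max_tokens=None, top_p=None,
                         presence_penalty=None, frequency_penalty=None)
    # _to_model_kwargs_input(rules, kwargs) would yield:
    # {'temperature': 1.0, 'max_new_tokens': 256}
    # temperature is clamped to the rule's max; max_tokens falls back to the
    # default and is renamed via its alias; disabled rules are omitted.
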
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + return decimal.Decimal('0') + + def get_currency(self): + return 'RMB' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, ValueError): + return LLMBadRequestError(f"ChatGLM: {str(ex)}") + else: + return ex + + @classmethod + def support_streaming(cls): + return False diff --git a/api/core/model_providers/models/llm/huggingface_hub_model.py b/api/core/model_providers/models/llm/huggingface_hub_model.py new file mode 100644 index 0000000000..f5deded517 --- /dev/null +++ b/api/core/model_providers/models/llm/huggingface_hub_model.py @@ -0,0 +1,82 @@ +import decimal +from functools import wraps +from typing import List, Optional, Any + +from langchain import HuggingFaceHub +from langchain.callbacks.manager import Callbacks +from langchain.llms import HuggingFaceEndpoint +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + + +class HuggingfaceHubModel(BaseLLM): + model_mode: ModelMode = ModelMode.COMPLETION + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + if self.credentials['huggingfacehub_api_type'] == 'inference_endpoints': + client = HuggingFaceEndpoint( + endpoint_url=self.credentials['huggingfacehub_endpoint_url'], + task='text2text-generation', + model_kwargs=provider_model_kwargs, + huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'], + callbacks=self.callbacks, + ) + else: + client = HuggingFaceHub( + repo_id=self.name, + task=self.credentials['task_type'], + model_kwargs=provider_model_kwargs, + huggingfacehub_api_token=self.credentials['huggingfacehub_api_token'], + callbacks=self.callbacks, + ) + + return client + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
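
ChatGLMModel above only needs a reachable self-hosted endpoint; a standalone sketch of the underlying langchain client call, with a hypothetical URL:

    from langchain.llms import ChatGLM

    llm = ChatGLM(endpoint_url='http://127.0.0.1:8000')
    completion = llm('Briefly introduce yourself.')
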
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.get_num_tokens(prompts) + + def get_token_price(self, tokens: int, message_type: MessageType): + # not support calc price + return decimal.Decimal('0') + + def get_currency(self): + return 'USD' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + self.client.model_kwargs = provider_model_kwargs + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"Huggingface Hub: {str(ex)}") + + @classmethod + def support_streaming(cls): + return False + diff --git a/api/core/model_providers/models/llm/minimax_model.py b/api/core/model_providers/models/llm/minimax_model.py new file mode 100644 index 0000000000..b7e38462f0 --- /dev/null +++ b/api/core/model_providers/models/llm/minimax_model.py @@ -0,0 +1,70 @@ +import decimal +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.llms import Minimax +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + + +class MinimaxModel(BaseLLM): + model_mode: ModelMode = ModelMode.COMPLETION + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + return Minimax( + model=self.name, + model_kwargs={ + 'stream': False + }, + callbacks=self.callbacks, + **self.credentials, + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
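+        Minimax is wrapped in completion mode here, so only the first prompt
+        message's content is converted and counted.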
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + return decimal.Decimal('0') + + def get_currency(self): + return 'RMB' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, ValueError): + return LLMBadRequestError(f"Minimax: {str(ex)}") + else: + return ex diff --git a/api/core/model_providers/models/llm/openai_model.py b/api/core/model_providers/models/llm/openai_model.py new file mode 100644 index 0000000000..e3dab3e9d7 --- /dev/null +++ b/api/core/model_providers/models/llm/openai_model.py @@ -0,0 +1,219 @@ +import decimal +import logging +from typing import List, Optional, Any + +import openai +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult + +from core.model_providers.providers.base import BaseModelProvider +from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError +from core.third_party.langchain.llms.open_ai import EnhanceOpenAI +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from models.provider import ProviderType, ProviderQuotaType + +COMPLETION_MODELS = [ + 'text-davinci-003', # 4,097 tokens +] + +CHAT_MODELS = [ + 'gpt-4', # 8,192 tokens + 'gpt-4-32k', # 32,768 tokens + 'gpt-3.5-turbo', # 4,096 tokens + 'gpt-3.5-turbo-16k', # 16,384 tokens +] + +MODEL_MAX_TOKENS = { + 'gpt-4': 8192, + 'gpt-4-32k': 32768, + 'gpt-3.5-turbo': 4096, + 'gpt-3.5-turbo-16k': 16384, + 'text-davinci-003': 4097, +} + + +class OpenAIModel(BaseLLM): + def __init__(self, model_provider: BaseModelProvider, + name: str, + model_kwargs: ModelKwargs, + streaming: bool = False, + callbacks: Callbacks = None): + if name in COMPLETION_MODELS: + self.model_mode = ModelMode.COMPLETION + else: + self.model_mode = ModelMode.CHAT + + super().__init__(model_provider, name, model_kwargs, streaming, callbacks) + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + if self.name in COMPLETION_MODELS: + client = EnhanceOpenAI( + model_name=self.name, + streaming=self.streaming, + callbacks=self.callbacks, + request_timeout=60, + **self.credentials, + **provider_model_kwargs + ) + else: + # Fine-tuning is currently only available for the following base models: + # davinci, curie, babbage, and ada. + # This means that except for the fixed `completion` model, + # all other fine-tuned models are `completion` models. 
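+            # ChatOpenAI accepts only a fixed set of top-level parameters, so
+            # the remaining sampling knobs are routed through `model_kwargs`.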
+ extra_model_kwargs = { + 'top_p': provider_model_kwargs.get('top_p'), + 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), + 'presence_penalty': provider_model_kwargs.get('presence_penalty'), + } + + client = EnhanceChatOpenAI( + model_name=self.name, + temperature=provider_model_kwargs.get('temperature'), + max_tokens=provider_model_kwargs.get('max_tokens'), + model_kwargs=extra_model_kwargs, + streaming=self.streaming, + callbacks=self.callbacks, + request_timeout=60, + **self.credentials + ) + + return client + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + if self.name == 'gpt-4' \ + and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \ + and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value: + raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.") + + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. + + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + if isinstance(prompts, str): + return self._client.get_num_tokens(prompts) + else: + return max(self._client.get_num_tokens_from_messages(prompts) - len(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + model_unit_prices = { + 'gpt-4': { + 'prompt': decimal.Decimal('0.03'), + 'completion': decimal.Decimal('0.06'), + }, + 'gpt-4-32k': { + 'prompt': decimal.Decimal('0.06'), + 'completion': decimal.Decimal('0.12') + }, + 'gpt-3.5-turbo': { + 'prompt': decimal.Decimal('0.0015'), + 'completion': decimal.Decimal('0.002') + }, + 'gpt-3.5-turbo-16k': { + 'prompt': decimal.Decimal('0.003'), + 'completion': decimal.Decimal('0.004') + }, + 'text-davinci-003': { + 'prompt': decimal.Decimal('0.02'), + 'completion': decimal.Decimal('0.02') + }, + } + + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + unit_price = model_unit_prices[self.name]['prompt'] + else: + unit_price = model_unit_prices[self.name]['completion'] + + tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1k * unit_price + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + def get_currency(self): + return 'USD' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + if self.name in COMPLETION_MODELS: + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + else: + extra_model_kwargs = { + 'top_p': provider_model_kwargs.get('top_p'), + 'frequency_penalty': provider_model_kwargs.get('frequency_penalty'), + 'presence_penalty': provider_model_kwargs.get('presence_penalty'), + } + + self.client.temperature = provider_model_kwargs.get('temperature') + self.client.max_tokens = provider_model_kwargs.get('max_tokens') + self.client.model_kwargs = extra_model_kwargs + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to OpenAI API.") + 
return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to OpenAI API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("OpenAI service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return LLMRateLimitError(str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + raise LLMAuthorizationError(str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) + else: + return ex + + @classmethod + def support_streaming(cls): + return True + + # def is_model_valid_or_raise(self): + # """ + # check is a valid model. + # + # :return: + # """ + # credentials = self._model_provider.get_credentials() + # + # try: + # result = openai.Model.retrieve( + # id=self.name, + # api_key=credentials.get('openai_api_key'), + # request_timeout=60 + # ) + # + # if 'id' not in result or result['id'] != self.name: + # raise LLMNotExistsError(f"OpenAI Model {self.name} not exists.") + # except openai.error.OpenAIError as e: + # raise LLMNotExistsError(f"OpenAI Model {self.name} not exists, cause: {e.__class__.__name__}:{str(e)}") + # except Exception as e: + # logging.exception("OpenAI Model retrieve failed.") + # raise e diff --git a/api/core/model_providers/models/llm/replicate_model.py b/api/core/model_providers/models/llm/replicate_model.py new file mode 100644 index 0000000000..7dd7eb8531 --- /dev/null +++ b/api/core/model_providers/models/llm/replicate_model.py @@ -0,0 +1,103 @@ +import decimal +from functools import wraps +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult, get_buffer_string +from replicate.exceptions import ReplicateError, ModelError + +from core.model_providers.providers.base import BaseModelProvider +from core.model_providers.error import LLMBadRequestError +from core.third_party.langchain.llms.replicate_llm import EnhanceReplicate +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs + + +class ReplicateModel(BaseLLM): + def __init__(self, model_provider: BaseModelProvider, + name: str, + model_kwargs: ModelKwargs, + streaming: bool = False, + callbacks: Callbacks = None): + self.model_mode = ModelMode.CHAT if name.endswith('-chat') else ModelMode.COMPLETION + + super().__init__(model_provider, name, model_kwargs, streaming, callbacks) + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + + return EnhanceReplicate( + model=self.name + ':' + self.credentials.get('model_version'), + input=provider_model_kwargs, + streaming=self.streaming, + replicate_api_token=self.credentials.get('replicate_api_token'), + callbacks=self.callbacks, + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. 
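+        For chat-tuned models, any system message is lifted into Replicate's
+        `system_prompt` input and the remaining messages are flattened into a
+        single prompt string.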
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        extra_kwargs = {}
+        if isinstance(prompts, list):
+            # `prompts` now holds converted langchain messages, whose `type`
+            # is the plain string 'system'; filtering the raw PromptMessage
+            # objects here would never match and would break get_buffer_string.
+            system_messages = [message for message in prompts if message.type == 'system']
+            if system_messages:
+                system_message = system_messages[0]
+                extra_kwargs['system_prompt'] = system_message.content
+                prompts = [message for message in prompts if message.type != 'system']
+
+            prompts = get_buffer_string(prompts)
+
+        # `max_length` bounds the total sequence length, i.e. the length of
+        # the input prompt plus the newly generated tokens.
+        if 'max_length' in self._client.input:
+            self._client.input['max_length'] = min(
+                self._client.input['max_length'] + self.get_num_tokens(messages),
+                self.model_rules.max_tokens.max
+            )
+
+        return self._client.generate([prompts], stop, callbacks, **extra_kwargs)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
+
+        :param messages:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        if isinstance(prompts, list):
+            prompts = get_buffer_string(prompts)
+
+        return self._client.get_num_tokens(prompts)
+
+    def get_token_price(self, tokens: int, message_type: MessageType):
+        # Replicate bills by prediction time, not by token, so the per-token
+        # price is reported as zero here.
+        return decimal.Decimal('0')
+
+    def get_currency(self):
+        return 'USD'
+
+    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
+        self.client.input = provider_model_kwargs
+
+    def handle_exceptions(self, ex: Exception) -> Exception:
+        if isinstance(ex, (ModelError, ReplicateError)):
+            return LLMBadRequestError(f"Replicate: {str(ex)}")
+        else:
+            return ex
+
+    @classmethod
+    def support_streaming(cls):
+        return True
\ No newline at end of file
diff --git a/api/core/model_providers/models/llm/spark_model.py b/api/core/model_providers/models/llm/spark_model.py
new file mode 100644
index 0000000000..5d8c97c463
--- /dev/null
+++ b/api/core/model_providers/models/llm/spark_model.py
@@ -0,0 +1,73 @@
+import decimal
+from functools import wraps
+from typing import List, Optional, Any
+
+from langchain.callbacks.manager import Callbacks
+from langchain.schema import LLMResult
+
+from core.model_providers.error import LLMBadRequestError
+from core.model_providers.models.llm.base import BaseLLM
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
+from core.third_party.langchain.llms.spark import ChatSpark
+from core.third_party.spark.spark_llm import SparkError
+
+
+class SparkModel(BaseLLM):
+    model_mode: ModelMode = ModelMode.CHAT
+
+    def _init_client(self) -> Any:
+        provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
+        return ChatSpark(
+            streaming=self.streaming,
+            callbacks=self.callbacks,
+            **self.credentials,
+            **provider_model_kwargs
+        )
+
+    def _run(self, messages: List[PromptMessage],
+             stop: Optional[List[str]] = None,
+             callbacks: Callbacks = None,
+             **kwargs) -> LLMResult:
+        """
+        run predict by prompt messages and stop words.
+
+        :param messages:
+        :param stop:
+        :param callbacks:
+        :return:
+        """
+        prompts = self._get_prompt_from_messages(messages)
+        return self._client.generate([prompts], stop, callbacks)
+
+    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
+        """
+        get num tokens of prompt messages.
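+        Spark does not expose a tokenizer through this client, so message
+        contents are simply concatenated and counted client-side.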
+ + :param messages: + :return: + """ + contents = [message.content for message in messages] + return max(self._client.get_num_tokens("".join(contents)), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + return decimal.Decimal('0') + + def get_currency(self): + return 'RMB' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, SparkError): + return LLMBadRequestError(f"Spark: {str(ex)}") + else: + return ex + + @classmethod + def support_streaming(cls): + return True \ No newline at end of file diff --git a/api/core/model_providers/models/llm/tongyi_model.py b/api/core/model_providers/models/llm/tongyi_model.py new file mode 100644 index 0000000000..f950275f77 --- /dev/null +++ b/api/core/model_providers/models/llm/tongyi_model.py @@ -0,0 +1,77 @@ +import decimal +from functools import wraps +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult +from requests import HTTPError + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi + + +class TongyiModel(BaseLLM): + model_mode: ModelMode = ModelMode.COMPLETION + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + del provider_model_kwargs['max_tokens'] + return EnhanceTongyi( + model_name=self.name, + max_retries=1, + streaming=self.streaming, + callbacks=self.callbacks, + **self.credentials, + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + return decimal.Decimal('0') + + def get_currency(self): + return 'RMB' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + del provider_model_kwargs['max_tokens'] + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, (ValueError, HTTPError)): + return LLMBadRequestError(f"Tongyi: {str(ex)}") + else: + return ex + + @classmethod + def support_streaming(cls): + return True diff --git a/api/core/model_providers/models/llm/wenxin_model.py b/api/core/model_providers/models/llm/wenxin_model.py new file mode 100644 index 0000000000..2c950679ab --- /dev/null +++ b/api/core/model_providers/models/llm/wenxin_model.py @@ -0,0 +1,92 @@ +import decimal +from typing import List, Optional, Any + +from langchain.callbacks.manager import Callbacks +from langchain.schema import LLMResult + +from core.model_providers.error import LLMBadRequestError +from core.model_providers.models.llm.base import BaseLLM +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs +from core.third_party.langchain.llms.wenxin import Wenxin + + +class WenxinModel(BaseLLM): + model_mode: ModelMode = ModelMode.COMPLETION + + def _init_client(self) -> Any: + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs) + return Wenxin( + streaming=self.streaming, + callbacks=self.callbacks, + **self.credentials, + **provider_model_kwargs + ) + + def _run(self, messages: List[PromptMessage], + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs) -> LLMResult: + """ + run predict by prompt messages and stop words. + + :param messages: + :param stop: + :param callbacks: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return self._client.generate([prompts], stop, callbacks) + + def get_num_tokens(self, messages: List[PromptMessage]) -> int: + """ + get num tokens of prompt messages. 
+ + :param messages: + :return: + """ + prompts = self._get_prompt_from_messages(messages) + return max(self._client.get_num_tokens(prompts), 0) + + def get_token_price(self, tokens: int, message_type: MessageType): + model_unit_prices = { + 'ernie-bot': { + 'prompt': decimal.Decimal('0.012'), + 'completion': decimal.Decimal('0.012'), + }, + 'ernie-bot-turbo': { + 'prompt': decimal.Decimal('0.008'), + 'completion': decimal.Decimal('0.008') + }, + 'bloomz-7b': { + 'prompt': decimal.Decimal('0.006'), + 'completion': decimal.Decimal('0.006') + } + } + + if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM: + unit_price = model_unit_prices[self.name]['prompt'] + else: + unit_price = model_unit_prices[self.name]['completion'] + + tokens_per_1k = (decimal.Decimal(tokens) / 1000).quantize(decimal.Decimal('0.001'), + rounding=decimal.ROUND_HALF_UP) + + total_price = tokens_per_1k * unit_price + return total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP) + + def get_currency(self): + return 'RMB' + + def _set_model_kwargs(self, model_kwargs: ModelKwargs): + provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs) + for k, v in provider_model_kwargs.items(): + if hasattr(self.client, k): + setattr(self.client, k, v) + + def handle_exceptions(self, ex: Exception) -> Exception: + return LLMBadRequestError(f"Wenxin: {str(ex)}") + + @classmethod + def support_streaming(cls): + return False diff --git a/api/core/model_providers/models/moderation/__init__.py b/api/core/model_providers/models/moderation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_providers/models/moderation/openai_moderation.py b/api/core/model_providers/models/moderation/openai_moderation.py new file mode 100644 index 0000000000..c1e792966b --- /dev/null +++ b/api/core/model_providers/models/moderation/openai_moderation.py @@ -0,0 +1,48 @@ +import logging + +import openai + +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, LLMAuthorizationError +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import BaseModelProvider + +DEFAULT_AUDIO_MODEL = 'whisper-1' + + +class OpenAIModeration(BaseProviderModel): + type: ModelType = ModelType.MODERATION + + def __init__(self, model_provider: BaseModelProvider, name: str): + super().__init__(model_provider, openai.Moderation) + + def run(self, text): + credentials = self.model_provider.get_model_credentials( + model_name=DEFAULT_AUDIO_MODEL, + model_type=self.type + ) + + try: + return self._client.create(input=text, api_key=credentials['openai_api_key']) + except Exception as ex: + raise self.handle_exceptions(ex) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to OpenAI API.") + return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to OpenAI API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("OpenAI service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return 
LLMRateLimitError(str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + raise LLMAuthorizationError(str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) + else: + return ex diff --git a/api/core/model_providers/models/speech2text/__init__.py b/api/core/model_providers/models/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_providers/models/speech2text/base.py b/api/core/model_providers/models/speech2text/base.py new file mode 100644 index 0000000000..0b1ec1d558 --- /dev/null +++ b/api/core/model_providers/models/speech2text/base.py @@ -0,0 +1,29 @@ +from abc import abstractmethod +from typing import Any + +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import BaseModelProvider + + +class BaseSpeech2Text(BaseProviderModel): + name: str + type: ModelType = ModelType.SPEECH_TO_TEXT + + def __init__(self, model_provider: BaseModelProvider, client: Any, name: str): + super().__init__(model_provider, client) + self.name = name + + def run(self, file): + try: + return self._run(file) + except Exception as ex: + raise self.handle_exceptions(ex) + + @abstractmethod + def _run(self, file): + raise NotImplementedError + + @abstractmethod + def handle_exceptions(self, ex: Exception) -> Exception: + raise NotImplementedError diff --git a/api/core/model_providers/models/speech2text/openai_whisper.py b/api/core/model_providers/models/speech2text/openai_whisper.py new file mode 100644 index 0000000000..8bca2aaa6d --- /dev/null +++ b/api/core/model_providers/models/speech2text/openai_whisper.py @@ -0,0 +1,47 @@ +import logging + +import openai + +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \ + LLMRateLimitError, LLMAuthorizationError +from core.model_providers.models.speech2text.base import BaseSpeech2Text +from core.model_providers.providers.base import BaseModelProvider + + +class OpenAIWhisper(BaseSpeech2Text): + + def __init__(self, model_provider: BaseModelProvider, name: str): + super().__init__(model_provider, openai.Audio, name) + + def _run(self, file): + credentials = self.model_provider.get_model_credentials( + model_name=self.name, + model_type=self.type + ) + + return self._client.transcribe( + model=self.name, + file=file, + api_key=credentials.get('openai_api_key'), + api_base=credentials.get('openai_api_base'), + organization=credentials.get('openai_organization'), + ) + + def handle_exceptions(self, ex: Exception) -> Exception: + if isinstance(ex, openai.error.InvalidRequestError): + logging.warning("Invalid request to OpenAI API.") + return LLMBadRequestError(str(ex)) + elif isinstance(ex, openai.error.APIConnectionError): + logging.warning("Failed to connect to OpenAI API.") + return LLMAPIConnectionError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, (openai.error.APIError, openai.error.ServiceUnavailableError, openai.error.Timeout)): + logging.warning("OpenAI service unavailable.") + return LLMAPIUnavailableError(ex.__class__.__name__ + ":" + str(ex)) + elif isinstance(ex, openai.error.RateLimitError): + return LLMRateLimitError(str(ex)) + elif isinstance(ex, openai.error.AuthenticationError): + raise LLMAuthorizationError(str(ex)) + elif isinstance(ex, openai.error.OpenAIError): + return LLMBadRequestError(ex.__class__.__name__ + ":" + str(ex)) + 
else: + return ex diff --git a/api/core/model_providers/providers/__init__.py b/api/core/model_providers/providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_providers/providers/anthropic_provider.py b/api/core/model_providers/providers/anthropic_provider.py new file mode 100644 index 0000000000..8daeff44e7 --- /dev/null +++ b/api/core/model_providers/providers/anthropic_provider.py @@ -0,0 +1,224 @@ +import json +import logging +from json import JSONDecodeError +from typing import Type, Optional + +import anthropic +from flask import current_app +from langchain.chat_models import ChatAnthropic +from langchain.schema import HumanMessage + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule +from core.model_providers.models.entity.provider import ModelFeature +from core.model_providers.models.llm.anthropic_model import AnthropicModel +from core.model_providers.models.llm.base import ModelType +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.model_providers.providers.hosted import hosted_model_providers +from models.provider import ProviderType + + +class AnthropicProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'anthropic' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'claude-instant-1', + 'name': 'claude-instant-1', + }, + { + 'id': 'claude-2', + 'name': 'claude-2', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = AnthropicModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=1, default=1), + top_p=KwargRule[float](min=0, max=1, default=0.7), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](alias="max_tokens_to_sample", min=10, max=100000, default=256), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. 
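+        A minimal "ping" completion is sent to claude-instant-1 so that an
+        invalid key or unreachable API URL fails fast.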
+ """ + if 'anthropic_api_key' not in credentials: + raise CredentialsValidateFailedError('Anthropic API Key must be provided.') + + try: + credential_kwargs = { + 'anthropic_api_key': credentials['anthropic_api_key'] + } + + if 'anthropic_api_url' in credentials: + credential_kwargs['anthropic_api_url'] = credentials['anthropic_api_url'] + + chat_llm = ChatAnthropic( + model='claude-instant-1', + max_tokens_to_sample=10, + temperature=0, + default_request_timeout=60, + **credential_kwargs + ) + + messages = [ + HumanMessage( + content="ping" + ) + ] + + chat_llm(messages) + except anthropic.APIConnectionError as ex: + raise CredentialsValidateFailedError(str(ex)) + except (anthropic.APIStatusError, anthropic.RateLimitError) as ex: + raise CredentialsValidateFailedError(str(ex)) + except Exception as ex: + logging.exception('Anthropic config validation failed') + raise ex + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['anthropic_api_key'] = encrypter.encrypt_token(tenant_id, credentials['anthropic_api_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'anthropic_api_url': None, + 'anthropic_api_key': None + } + + if credentials['anthropic_api_key']: + credentials['anthropic_api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['anthropic_api_key'] + ) + + if obfuscated: + credentials['anthropic_api_key'] = encrypter.obfuscated_token(credentials['anthropic_api_key']) + + if 'anthropic_api_url' not in credentials: + credentials['anthropic_api_url'] = None + + return credentials + else: + if hosted_model_providers.anthropic: + return { + 'anthropic_api_url': hosted_model_providers.anthropic.api_base, + 'anthropic_api_key': hosted_model_providers.anthropic.api_key, + } + else: + return { + 'anthropic_api_url': None, + 'anthropic_api_key': None + } + + @classmethod + def is_provider_type_system_supported(cls) -> bool: + if current_app.config['EDITION'] != 'CLOUD': + return False + + if hosted_model_providers.anthropic: + return True + + return False + + def should_deduct_quota(self): + if hosted_model_providers.anthropic and \ + hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0: + return True + + return False + + def get_payment_info(self) -> Optional[dict]: + """ + get product info if it payable. + + :return: + """ + if hosted_model_providers.anthropic \ + and hosted_model_providers.anthropic.paid_enabled: + return { + 'product_id': hosted_model_providers.anthropic.paid_stripe_price_id, + 'increase_quota': hosted_model_providers.anthropic.paid_increase_quota, + } + + return None + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. 
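+        Anthropic only has provider-level credentials, so there is nothing to
+        encrypt per model and an empty dict is returned.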
+ + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/providers/azure_openai_provider.py b/api/core/model_providers/providers/azure_openai_provider.py new file mode 100644 index 0000000000..3dbb78237d --- /dev/null +++ b/api/core/model_providers/providers/azure_openai_provider.py @@ -0,0 +1,387 @@ +import json +import logging +from json import JSONDecodeError +from typing import Type + +import openai +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings +from langchain.schema import HumanMessage + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \ + AZURE_OPENAI_API_VERSION +from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule +from core.model_providers.models.entity.provider import ModelFeature +from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.model_providers.providers.hosted import hosted_model_providers +from core.third_party.langchain.llms.azure_chat_open_ai import EnhanceAzureChatOpenAI +from extensions.ext_database import db +from models.provider import ProviderType, ProviderModel, ProviderQuotaType + +BASE_MODELS = [ + 'gpt-4', + 'gpt-4-32k', + 'gpt-35-turbo', + 'gpt-35-turbo-16k', + 'text-davinci-003', + 'text-embedding-ada-002', +] + + +class AzureOpenAIProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. 
+ """ + return 'azure_openai' + + def get_supported_model_list(self, model_type: ModelType) -> list[dict]: + # convert old provider config to provider models + self._convert_provider_config_to_model_config() + + if self.provider.provider_type == ProviderType.CUSTOM.value: + # get configurable provider models + provider_models = db.session.query(ProviderModel).filter( + ProviderModel.tenant_id == self.provider.tenant_id, + ProviderModel.provider_name == self.provider.provider_name, + ProviderModel.model_type == model_type.value, + ProviderModel.is_valid == True + ).order_by(ProviderModel.created_at.asc()).all() + + model_list = [] + for provider_model in provider_models: + model_dict = { + 'id': provider_model.model_name, + 'name': provider_model.model_name + } + + credentials = json.loads(provider_model.encrypted_config) + if credentials['base_model_name'] in [ + 'gpt-4', + 'gpt-4-32k', + 'gpt-35-turbo', + 'gpt-35-turbo-16k', + ]: + model_dict['features'] = [ + ModelFeature.AGENT_THOUGHT.value + ] + + model_list.append(model_dict) + else: + model_list = self._get_fixed_model_list(model_type) + + return model_list + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + models = [ + { + 'id': 'gpt-3.5-turbo', + 'name': 'gpt-3.5-turbo', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'gpt-3.5-turbo-16k', + 'name': 'gpt-3.5-turbo-16k', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'gpt-4', + 'name': 'gpt-4', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'gpt-4-32k', + 'name': 'gpt-4-32k', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'text-davinci-003', + 'name': 'text-davinci-003', + } + ] + + if self.provider.provider_type == ProviderType.SYSTEM.value \ + and self.provider.quota_type == ProviderQuotaType.TRIAL.value: + models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']] + + return models + elif model_type == ModelType.EMBEDDINGS: + return [ + { + 'id': 'text-embedding-ada-002', + 'name': 'text-embedding-ada-002' + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = AzureOpenAIModel + elif model_type == ModelType.EMBEDDINGS: + model_class = AzureOpenAIEmbedding + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + base_model_max_tokens = { + 'gpt-4': 8192, + 'gpt-4-32k': 32768, + 'gpt-35-turbo': 4096, + 'gpt-35-turbo-16k': 16384, + 'text-davinci-003': 4097, + } + + model_credentials = self.get_model_credentials(model_name, model_type) + + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=2, default=1), + top_p=KwargRule[float](min=0, max=1, default=1), + presence_penalty=KwargRule[float](min=-2, max=2, default=0), + frequency_penalty=KwargRule[float](min=-2, max=2, default=0), + max_tokens=KwargRule[int](min=10, max=base_model_max_tokens.get( + model_credentials['base_model_name'], + 4097 + ), default=16), + ) + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. 
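+        Expected keys (values shown are illustrative placeholders only):
+            openai_api_key: "<azure-api-key>"
+            openai_api_base: "https://<resource>.openai.azure.com"
+            base_model_name: an entry from BASE_MODELS, e.g. "gpt-35-turbo"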
+ + :param model_name: + :param model_type: + :param credentials: + """ + if 'openai_api_key' not in credentials: + raise CredentialsValidateFailedError('Azure OpenAI API key is required') + + if 'openai_api_base' not in credentials: + raise CredentialsValidateFailedError('Azure OpenAI API Base Endpoint is required') + + if 'base_model_name' not in credentials: + raise CredentialsValidateFailedError('Base Model Name is required') + + if credentials['base_model_name'] not in BASE_MODELS: + raise CredentialsValidateFailedError('Base Model Name is invalid') + + if model_type == ModelType.TEXT_GENERATION: + try: + client = EnhanceAzureChatOpenAI( + deployment_name=model_name, + temperature=0, + max_tokens=15, + request_timeout=10, + openai_api_type='azure', + openai_api_version='2023-07-01-preview', + openai_api_key=credentials['openai_api_key'], + openai_api_base=credentials['openai_api_base'], + ) + + client.generate([[HumanMessage(content='hi!')]]) + except openai.error.OpenAIError as e: + raise CredentialsValidateFailedError( + f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}") + except Exception as e: + logging.exception("Azure OpenAI Model retrieve failed.") + raise e + elif model_type == ModelType.EMBEDDINGS: + try: + client = OpenAIEmbeddings( + openai_api_type='azure', + openai_api_version=AZURE_OPENAI_API_VERSION, + deployment=model_name, + chunk_size=16, + max_retries=1, + openai_api_key=credentials['openai_api_key'], + openai_api_base=credentials['openai_api_base'] + ) + + client.embed_query('hi') + except openai.error.OpenAIError as e: + logging.exception("Azure OpenAI Model check error.") + raise CredentialsValidateFailedError( + f"Azure OpenAI deployment {model_name} not exists, cause: {e.__class__.__name__}:{str(e)}") + except Exception as e: + logging.exception("Azure OpenAI Model retrieve failed.") + raise e + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key']) + return credentials + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. 
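+        For CUSTOM providers, the per-model encrypted config is decrypted;
+        for the hosted SYSTEM provider, the shared Azure credentials are
+        returned with the deployment name reused as the base model name.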
+ + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + if self.provider.provider_type == ProviderType.CUSTOM.value: + # convert old provider config to provider models + self._convert_provider_config_to_model_config() + + provider_model = self._get_provider_model(model_name, model_type) + + if not provider_model.encrypted_config: + return { + 'openai_api_base': '', + 'openai_api_key': '', + 'base_model_name': '' + } + + credentials = json.loads(provider_model.encrypted_config) + if credentials['openai_api_key']: + credentials['openai_api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['openai_api_key'] + ) + + if obfuscated: + credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key']) + + return credentials + else: + if hosted_model_providers.azure_openai: + return { + 'openai_api_base': hosted_model_providers.azure_openai.api_base, + 'openai_api_key': hosted_model_providers.azure_openai.api_key, + 'base_model_name': model_name + } + else: + return { + 'openai_api_base': None, + 'openai_api_key': None, + 'base_model_name': None + } + + @classmethod + def is_provider_type_system_supported(cls) -> bool: + if current_app.config['EDITION'] != 'CLOUD': + return False + + if hosted_model_providers.azure_openai: + return True + + return False + + def should_deduct_quota(self): + if hosted_model_providers.azure_openai \ + and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0: + return True + + return False + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + return + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + return {} + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + return {} + + def _convert_provider_config_to_model_config(self): + if self.provider.provider_type == ProviderType.CUSTOM.value \ + and self.provider.is_valid \ + and self.provider.encrypted_config: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'openai_api_base': '', + 'openai_api_key': '', + 'base_model_name': '' + } + + self._add_provider_model( + model_name='gpt-35-turbo', + model_type=ModelType.TEXT_GENERATION, + provider_credentials=credentials + ) + + self._add_provider_model( + model_name='gpt-35-turbo-16k', + model_type=ModelType.TEXT_GENERATION, + provider_credentials=credentials + ) + + self._add_provider_model( + model_name='gpt-4', + model_type=ModelType.TEXT_GENERATION, + provider_credentials=credentials + ) + + self._add_provider_model( + model_name='text-davinci-003', + model_type=ModelType.TEXT_GENERATION, + provider_credentials=credentials + ) + + self._add_provider_model( + model_name='text-embedding-ada-002', + model_type=ModelType.EMBEDDINGS, + provider_credentials=credentials + ) + + self.provider.encrypted_config = None + db.session.commit() + + def _add_provider_model(self, model_name: str, model_type: ModelType, provider_credentials: dict): + credentials = provider_credentials.copy() + credentials['base_model_name'] = model_name + provider_model = ProviderModel( + tenant_id=self.provider.tenant_id, + provider_name=self.provider.provider_name, + model_name=model_name, + model_type=model_type.value, + encrypted_config=json.dumps(credentials), + is_valid=True + ) + db.session.add(provider_model) + db.session.commit() diff --git a/api/core/model_providers/providers/base.py 
b/api/core/model_providers/providers/base.py new file mode 100644 index 0000000000..f10aa9f99d --- /dev/null +++ b/api/core/model_providers/providers/base.py @@ -0,0 +1,283 @@ +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Type, Optional + +from flask import current_app +from pydantic import BaseModel + +from core.model_providers.error import QuotaExceededError, LLMBadRequestError +from extensions.ext_database import db +from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules +from core.model_providers.models.entity.provider import ProviderQuotaUnit +from core.model_providers.rules import provider_rules +from models.provider import Provider, ProviderType, ProviderModel + + +class BaseModelProvider(BaseModel, ABC): + + provider: Provider + + class Config: + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @property + @abstractmethod + def provider_name(self): + """ + Returns the name of a provider. + """ + raise NotImplementedError + + def get_rules(self): + """ + Returns the rules of a provider. + """ + return provider_rules[self.provider_name] + + def get_supported_model_list(self, model_type: ModelType) -> list[dict]: + """ + get supported model object list for use. + + :param model_type: + :return: + """ + rules = self.get_rules() + if 'custom' not in rules['support_provider_types']: + return self._get_fixed_model_list(model_type) + + if 'model_flexibility' not in rules: + return self._get_fixed_model_list(model_type) + + if rules['model_flexibility'] == 'fixed': + return self._get_fixed_model_list(model_type) + + # get configurable provider models + provider_models = db.session.query(ProviderModel).filter( + ProviderModel.tenant_id == self.provider.tenant_id, + ProviderModel.provider_name == self.provider.provider_name, + ProviderModel.model_type == model_type.value, + ProviderModel.is_valid == True + ).order_by(ProviderModel.created_at.asc()).all() + + return [{ + 'id': provider_model.model_name, + 'name': provider_model.model_name + } for provider_model in provider_models] + + @abstractmethod + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + """ + get supported model object list for use. + + :param model_type: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_model_class(self, model_type: ModelType) -> Type: + """ + get specific model class. + + :param model_type: + :return: + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + check provider credentials valid. + + :param credentials: + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + """ + encrypt provider credentials for save. + + :param tenant_id: + :param credentials: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param obfuscated: + :return: + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. 
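+        Implementations are expected to raise CredentialsValidateFailedError
+        when the given credentials cannot be verified.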
+ + :param model_name: + :param model_type: + :param credentials: + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + raise NotImplementedError + + @abstractmethod + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + raise NotImplementedError + + @classmethod + def is_provider_type_system_supported(cls) -> bool: + return current_app.config['EDITION'] == 'CLOUD' + + def check_quota_over_limit(self): + """ + check provider quota over limit. + + :return: + """ + if self.provider.provider_type != ProviderType.SYSTEM.value: + return + + rules = self.get_rules() + if 'system' not in rules['support_provider_types']: + return + + provider = db.session.query(Provider).filter( + db.and_( + Provider.id == self.provider.id, + Provider.is_valid == True, + Provider.quota_limit > Provider.quota_used + ) + ).first() + + if not provider: + raise QuotaExceededError() + + def deduct_quota(self, used_tokens: int = 0) -> None: + """ + deduct available quota when provider type is system or paid. + + :return: + """ + if self.provider.provider_type != ProviderType.SYSTEM.value: + return + + rules = self.get_rules() + if 'system' not in rules['support_provider_types']: + return + + if not self.should_deduct_quota(): + return + + if 'system_config' not in rules: + quota_unit = ProviderQuotaUnit.TIMES.value + elif 'quota_unit' not in rules['system_config']: + quota_unit = ProviderQuotaUnit.TIMES.value + else: + quota_unit = rules['system_config']['quota_unit'] + + if quota_unit == ProviderQuotaUnit.TOKENS.value: + used_quota = used_tokens + else: + used_quota = 1 + + db.session.query(Provider).filter( + Provider.tenant_id == self.provider.tenant_id, + Provider.provider_name == self.provider.provider_name, + Provider.provider_type == self.provider.provider_type, + Provider.quota_type == self.provider.quota_type, + Provider.quota_limit > Provider.quota_used + ).update({'quota_used': Provider.quota_used + used_quota}) + db.session.commit() + + def should_deduct_quota(self): + return False + + def update_last_used(self) -> None: + """ + update last used time. + + :return: + """ + db.session.query(Provider).filter( + Provider.tenant_id == self.provider.tenant_id, + Provider.provider_name == self.provider.provider_name + ).update({'last_used': datetime.utcnow()}) + db.session.commit() + + def get_payment_info(self) -> Optional[dict]: + """ + get product info if it payable. + + :return: + """ + return None + + def _get_provider_model(self, model_name: str, model_type: ModelType) -> ProviderModel: + """ + get provider model. 
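+        Only records flagged is_valid are considered; a missing record is
+        surfaced to the caller as an LLMBadRequestError.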
+ + :param model_name: + :param model_type: + :return: + """ + provider_model = db.session.query(ProviderModel).filter( + ProviderModel.tenant_id == self.provider.tenant_id, + ProviderModel.provider_name == self.provider.provider_name, + ProviderModel.model_name == model_name, + ProviderModel.model_type == model_type.value, + ProviderModel.is_valid == True + ).first() + + if not provider_model: + raise LLMBadRequestError(f"The model {model_name} does not exist. " + f"Please check the configuration.") + + return provider_model + + +class CredentialsValidateFailedError(Exception): + pass diff --git a/api/core/model_providers/providers/chatglm_provider.py b/api/core/model_providers/providers/chatglm_provider.py new file mode 100644 index 0000000000..f905da6f23 --- /dev/null +++ b/api/core/model_providers/providers/chatglm_provider.py @@ -0,0 +1,157 @@ +import json +from json import JSONDecodeError +from typing import Type + +from langchain.llms import ChatGLM + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.chatglm_model import ChatGLMModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from models.provider import ProviderType + + +class ChatGLMProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'chatglm' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'chatglm2-6b', + 'name': 'ChatGLM2-6B', + }, + { + 'id': 'chatglm-6b', + 'name': 'ChatGLM-6B', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = ChatGLMModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + model_max_tokens = { + 'chatglm-6b': 2000, + 'chatglm2-6b': 32000, + } + + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=2, default=1), + top_p=KwargRule[float](min=0, max=1, default=0.7), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](alias='max_token', min=10, max=model_max_tokens.get(model_name), default=2048), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. 
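+        Expects `api_base` to point at a reachable ChatGLM-compatible HTTP
+        endpoint; a tiny "ping" completion is issued to verify it.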
+ """ + if 'api_base' not in credentials: + raise CredentialsValidateFailedError('ChatGLM Endpoint URL must be provided.') + + try: + credential_kwargs = { + 'endpoint_url': credentials['api_base'] + } + + llm = ChatGLM( + max_token=10, + **credential_kwargs + ) + + llm("ping") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['api_base'] = encrypter.encrypt_token(tenant_id, credentials['api_base']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'api_base': None + } + + if credentials['api_base']: + credentials['api_base'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['api_base'] + ) + + if obfuscated: + credentials['api_base'] = encrypter.obfuscated_token(credentials['api_base']) + + return credentials + + return {} + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/providers/hosted.py b/api/core/model_providers/providers/hosted.py new file mode 100644 index 0000000000..b34153d0ab --- /dev/null +++ b/api/core/model_providers/providers/hosted.py @@ -0,0 +1,76 @@ +import os +from typing import Optional + +import langchain +from flask import Flask +from pydantic import BaseModel + + +class HostedOpenAI(BaseModel): + api_base: str = None + api_organization: str = None + api_key: str + quota_limit: int = 0 + """Quota limit for the openai hosted model. 0 means unlimited.""" + paid_enabled: bool = False + paid_stripe_price_id: str = None + paid_increase_quota: int = 1 + + +class HostedAzureOpenAI(BaseModel): + api_base: str + api_key: str + quota_limit: int = 0 + """Quota limit for the azure openai hosted model. 0 means unlimited.""" + + +class HostedAnthropic(BaseModel): + api_base: str = None + api_key: str + quota_limit: int = 0 + """Quota limit for the anthropic hosted model. 
0 means unlimited.""" + paid_enabled: bool = False + paid_stripe_price_id: str = None + paid_increase_quota: int = 1 + + +class HostedModelProviders(BaseModel): + openai: Optional[HostedOpenAI] = None + azure_openai: Optional[HostedAzureOpenAI] = None + anthropic: Optional[HostedAnthropic] = None + + +hosted_model_providers = HostedModelProviders() + + +def init_app(app: Flask): + if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true': + langchain.verbose = True + + if app.config.get("HOSTED_OPENAI_ENABLED"): + hosted_model_providers.openai = HostedOpenAI( + api_base=app.config.get("HOSTED_OPENAI_API_BASE"), + api_organization=app.config.get("HOSTED_OPENAI_API_ORGANIZATION"), + api_key=app.config.get("HOSTED_OPENAI_API_KEY"), + quota_limit=app.config.get("HOSTED_OPENAI_QUOTA_LIMIT"), + paid_enabled=app.config.get("HOSTED_OPENAI_PAID_ENABLED"), + paid_stripe_price_id=app.config.get("HOSTED_OPENAI_PAID_STRIPE_PRICE_ID"), + paid_increase_quota=app.config.get("HOSTED_OPENAI_PAID_INCREASE_QUOTA"), + ) + + if app.config.get("HOSTED_AZURE_OPENAI_ENABLED"): + hosted_model_providers.azure_openai = HostedAzureOpenAI( + api_base=app.config.get("HOSTED_AZURE_OPENAI_API_BASE"), + api_key=app.config.get("HOSTED_AZURE_OPENAI_API_KEY"), + quota_limit=app.config.get("HOSTED_AZURE_OPENAI_QUOTA_LIMIT"), + ) + + if app.config.get("HOSTED_ANTHROPIC_ENABLED"): + hosted_model_providers.anthropic = HostedAnthropic( + api_base=app.config.get("HOSTED_ANTHROPIC_API_BASE"), + api_key=app.config.get("HOSTED_ANTHROPIC_API_KEY"), + quota_limit=app.config.get("HOSTED_ANTHROPIC_QUOTA_LIMIT"), + paid_enabled=app.config.get("HOSTED_ANTHROPIC_PAID_ENABLED"), + paid_stripe_price_id=app.config.get("HOSTED_ANTHROPIC_PAID_STRIPE_PRICE_ID"), + paid_increase_quota=app.config.get("HOSTED_ANTHROPIC_PAID_INCREASE_QUOTA"), + ) diff --git a/api/core/model_providers/providers/huggingface_hub_provider.py b/api/core/model_providers/providers/huggingface_hub_provider.py new file mode 100644 index 0000000000..ded94e2a44 --- /dev/null +++ b/api/core/model_providers/providers/huggingface_hub_provider.py @@ -0,0 +1,183 @@ +import json +from typing import Type + +from huggingface_hub import HfApi +from langchain.llms import HuggingFaceEndpoint + +from core.helper import encrypter +from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType +from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError + +from core.model_providers.models.base import BaseProviderModel +from models.provider import ProviderType + + +class HuggingfaceHubProvider(BaseModelProvider): + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'huggingface_hub' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = HuggingfaceHubModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. 
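+        All Hugging Face Hub models share this fixed rule set; max_tokens is forwarded as 'max_new_tokens'.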
+ + :param model_name: + :param model_type: + :return: + """ + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=2, default=1), + top_p=KwargRule[float](min=0.01, max=0.99, default=0.7), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](alias='max_new_tokens', min=10, max=1500, default=200), + ) + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + if model_type != ModelType.TEXT_GENERATION: + raise NotImplementedError + + if 'huggingfacehub_api_type' not in credentials \ + or credentials['huggingfacehub_api_type'] not in ['hosted_inference_api', 'inference_endpoints']: + raise CredentialsValidateFailedError('Hugging Face Hub API Type invalid, ' + 'must be hosted_inference_api or inference_endpoints.') + + if 'huggingfacehub_api_token' not in credentials: + raise CredentialsValidateFailedError('Hugging Face Hub API Token must be provided.') + + hfapi = HfApi(token=credentials['huggingfacehub_api_token']) + + try: + hfapi.whoami() + except Exception: + raise CredentialsValidateFailedError("Invalid API Token.") + + if credentials['huggingfacehub_api_type'] == 'inference_endpoints': + if 'huggingfacehub_endpoint_url' not in credentials: + raise CredentialsValidateFailedError('Hugging Face Hub Endpoint URL must be provided.') + + try: + llm = HuggingFaceEndpoint( + endpoint_url=credentials['huggingfacehub_endpoint_url'], + task="text2text-generation", + model_kwargs={"temperature": 0.5, "max_new_tokens": 200}, + huggingfacehub_api_token=credentials['huggingfacehub_api_token'] + ) + + llm("ping") + except Exception as e: + raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}") + else: + try: + model_info = hfapi.model_info(repo_id=model_name) + if not model_info: + raise ValueError(f'Model {model_name} not found.') + + if 'inference' in model_info.cardData and not model_info.cardData['inference']: + raise ValueError(f'Inference API has been turned off for this model {model_name}.') + + VALID_TASKS = ("text2text-generation", "text-generation", "summarization") + if model_info.pipeline_tag not in VALID_TASKS: + raise ValueError(f"Model {model_name} is not a valid task, " + f"must be one of {VALID_TASKS}.") + except Exception as e: + raise CredentialsValidateFailedError(f"{e.__class__.__name__}:{str(e)}") + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. 
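+        For hosted_inference_api models, the pipeline task is looked up via HfApi and stored as 'task_type'.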
+ + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + credentials['huggingfacehub_api_token'] = encrypter.encrypt_token(tenant_id, credentials['huggingfacehub_api_token']) + + if credentials['huggingfacehub_api_type'] == 'hosted_inference_api': + hfapi = HfApi(token=credentials['huggingfacehub_api_token']) + model_info = hfapi.model_info(repo_id=model_name) + if not model_info: + raise ValueError(f'Model {model_name} not found.') + + if 'inference' in model_info.cardData and not model_info.cardData['inference']: + raise ValueError(f'Inference API has been turned off for this model {model_name}.') + + credentials['task_type'] = model_info.pipeline_tag + + return credentials + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + if self.provider.provider_type != ProviderType.CUSTOM.value: + raise NotImplementedError + + provider_model = self._get_provider_model(model_name, model_type) + + if not provider_model.encrypted_config: + return { + 'huggingfacehub_api_token': None, + 'task_type': None + } + + credentials = json.loads(provider_model.encrypted_config) + if credentials['huggingfacehub_api_token']: + credentials['huggingfacehub_api_token'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['huggingfacehub_api_token'] + ) + + if obfuscated: + credentials['huggingfacehub_api_token'] = encrypter.obfuscated_token(credentials['huggingfacehub_api_token']) + + return credentials + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + return + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + return {} + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + return {} diff --git a/api/core/model_providers/providers/minimax_provider.py b/api/core/model_providers/providers/minimax_provider.py new file mode 100644 index 0000000000..46ec84a6d8 --- /dev/null +++ b/api/core/model_providers/providers/minimax_provider.py @@ -0,0 +1,179 @@ +import json +from json import JSONDecodeError +from typing import Type + +from langchain.llms import Minimax + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.minimax_model import MinimaxModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from models.provider import ProviderType, ProviderQuotaType + + +class MinimaxProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'minimax' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'abab5.5-chat', + 'name': 'abab5.5-chat', + }, + { + 'id': 'abab5-chat', + 'name': 'abab5-chat', + } + ] + elif model_type == ModelType.EMBEDDINGS: + return [ + { + 'id': 'embo-01', + 'name': 'embo-01', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. 
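+        MinimaxModel handles text generation; MinimaxEmbedding handles embeddings.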
+ + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = MinimaxModel + elif model_type == ModelType.EMBEDDINGS: + model_class = MinimaxEmbedding + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + model_max_tokens = { + 'abab5.5-chat': 16384, + 'abab5-chat': 6144, + } + + return ModelKwargsRules( + temperature=KwargRule[float](min=0.01, max=1, default=0.9), + top_p=KwargRule[float](min=0, max=1, default=0.95), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 6144), default=1024), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. + """ + if 'minimax_group_id' not in credentials: + raise CredentialsValidateFailedError('MiniMax Group ID must be provided.') + + if 'minimax_api_key' not in credentials: + raise CredentialsValidateFailedError('MiniMax API Key must be provided.') + + try: + credential_kwargs = { + 'minimax_group_id': credentials['minimax_group_id'], + 'minimax_api_key': credentials['minimax_api_key'], + } + + llm = Minimax( + model='abab5.5-chat', + max_tokens=10, + temperature=0.01, + **credential_kwargs + ) + + llm("ping") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['minimax_api_key'] = encrypter.encrypt_token(tenant_id, credentials['minimax_api_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value \ + or (self.provider.provider_type == ProviderType.SYSTEM.value + and self.provider.quota_type == ProviderQuotaType.FREE.value): + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'minimax_group_id': None, + 'minimax_api_key': None, + } + + if credentials['minimax_api_key']: + credentials['minimax_api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['minimax_api_key'] + ) + + if obfuscated: + credentials['minimax_api_key'] = encrypter.obfuscated_token(credentials['minimax_api_key']) + + return credentials + + return {} + + def should_deduct_quota(self): + return True + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. 
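+        MiniMax credentials are provider-wide, so this delegates to get_provider_credentials.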
+ + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/providers/openai_provider.py b/api/core/model_providers/providers/openai_provider.py new file mode 100644 index 0000000000..0041d23ca6 --- /dev/null +++ b/api/core/model_providers/providers/openai_provider.py @@ -0,0 +1,289 @@ +import json +import logging +from json import JSONDecodeError +from typing import Type, Optional + +from flask import current_app +from openai.error import AuthenticationError, OpenAIError + +import openai + +from core.helper import encrypter +from core.model_providers.models.entity.provider import ModelFeature +from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.openai_model import OpenAIModel +from core.model_providers.models.moderation.openai_moderation import OpenAIModeration +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.model_providers.providers.hosted import hosted_model_providers +from models.provider import ProviderType, ProviderQuotaType + + +class OpenAIProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'openai' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + models = [ + { + 'id': 'gpt-3.5-turbo', + 'name': 'gpt-3.5-turbo', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'gpt-3.5-turbo-16k', + 'name': 'gpt-3.5-turbo-16k', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'gpt-4', + 'name': 'gpt-4', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'gpt-4-32k', + 'name': 'gpt-4-32k', + 'features': [ + ModelFeature.AGENT_THOUGHT.value + ] + }, + { + 'id': 'text-davinci-003', + 'name': 'text-davinci-003', + } + ] + + if self.provider.provider_type == ProviderType.SYSTEM.value \ + and self.provider.quota_type == ProviderQuotaType.TRIAL.value: + models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']] + + return models + elif model_type == ModelType.EMBEDDINGS: + return [ + { + 'id': 'text-embedding-ada-002', + 'name': 'text-embedding-ada-002' + } + ] + elif model_type == ModelType.SPEECH_TO_TEXT: + return [ + { + 'id': 'whisper-1', + 'name': 'whisper-1' + } + ] + elif model_type == ModelType.MODERATION: + return [ + { + 'id': 'text-moderation-stable', + 'name': 'text-moderation-stable' + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = OpenAIModel + elif model_type == ModelType.EMBEDDINGS: + model_class = OpenAIEmbedding + elif model_type == ModelType.MODERATION: + model_class = OpenAIModeration + elif model_type == ModelType.SPEECH_TO_TEXT: + model_class = OpenAIWhisper + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. 
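+        max_tokens is capped by each model's context size (e.g. 8192 for gpt-4, 16384 for gpt-3.5-turbo-16k; fallback 4097).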
+ + :param model_name: + :param model_type: + :return: + """ + model_max_tokens = { + 'gpt-4': 8192, + 'gpt-4-32k': 32768, + 'gpt-3.5-turbo': 4096, + 'gpt-3.5-turbo-16k': 16384, + 'text-davinci-003': 4097, + } + + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=2, default=1), + top_p=KwargRule[float](min=0, max=1, default=1), + presence_penalty=KwargRule[float](min=-2, max=2, default=0), + frequency_penalty=KwargRule[float](min=-2, max=2, default=0), + max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name, 4097), default=16), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. + """ + if 'openai_api_key' not in credentials: + raise CredentialsValidateFailedError('OpenAI API key is required') + + try: + credentials_kwargs = { + "api_key": credentials['openai_api_key'] + } + + if 'openai_api_base' in credentials and credentials['openai_api_base']: + credentials_kwargs['api_base'] = credentials['openai_api_base'] + '/v1' + + if 'openai_organization' in credentials: + credentials_kwargs['organization'] = credentials['openai_organization'] + + openai.ChatCompletion.create( + messages=[{"role": "user", "content": 'ping'}], + model='gpt-3.5-turbo', + timeout=10, + request_timeout=(5, 30), + max_tokens=20, + **credentials_kwargs + ) + except (AuthenticationError, OpenAIError) as ex: + raise CredentialsValidateFailedError(str(ex)) + except Exception as ex: + logging.exception('OpenAI config validation failed') + raise ex + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['openai_api_key'] = encrypter.encrypt_token(tenant_id, credentials['openai_api_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'openai_api_base': None, + 'openai_api_key': self.provider.encrypted_config, + 'openai_organization': None + } + + if credentials['openai_api_key']: + credentials['openai_api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['openai_api_key'] + ) + + if obfuscated: + credentials['openai_api_key'] = encrypter.obfuscated_token(credentials['openai_api_key']) + + if 'openai_api_base' not in credentials or not credentials['openai_api_base']: + credentials['openai_api_base'] = None + else: + credentials['openai_api_base'] = credentials['openai_api_base'] + '/v1' + + if 'openai_organization' not in credentials: + credentials['openai_organization'] = None + + return credentials + else: + if hosted_model_providers.openai: + return { + 'openai_api_base': hosted_model_providers.openai.api_base, + 'openai_api_key': hosted_model_providers.openai.api_key, + 'openai_organization': hosted_model_providers.openai.api_organization + } + else: + return { + 'openai_api_base': None, + 'openai_api_key': None, + 'openai_organization': None + } + + @classmethod + def is_provider_type_system_supported(cls) -> bool: + if current_app.config['EDITION'] != 'CLOUD': + return False + + if hosted_model_providers.openai: + return True + + return False + + def should_deduct_quota(self): + if hosted_model_providers.openai \ + and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0: + return True + + return False + + def get_payment_info(self) -> Optional[dict]: + """ + get 
payment info if it is payable.
+
+        :return:
+        """
+        if hosted_model_providers.openai \
+                and hosted_model_providers.openai.paid_enabled:
+            return {
+                'product_id': hosted_model_providers.openai.paid_stripe_price_id,
+                'increase_quota': hosted_model_providers.openai.paid_increase_quota,
+            }
+
+        return None
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        return
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        return {}
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
+
+        :param model_name:
+        :param model_type:
+        :param obfuscated:
+        :return:
+        """
+        return self.get_provider_credentials(obfuscated)
diff --git a/api/core/model_providers/providers/replicate_provider.py b/api/core/model_providers/providers/replicate_provider.py
new file mode 100644
index 0000000000..404ca1c57c
--- /dev/null
+++ b/api/core/model_providers/providers/replicate_provider.py
@@ -0,0 +1,184 @@
+import json
+import logging
+from typing import Type
+
+import replicate
+from replicate.exceptions import ReplicateError
+
+from core.helper import encrypter
+from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
+from core.model_providers.models.llm.replicate_model import ReplicateModel
+from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
+
+from core.model_providers.models.base import BaseProviderModel
+from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding
+from models.provider import ProviderType
+
+
+class ReplicateProvider(BaseModelProvider):
+    @property
+    def provider_name(self):
+        """
+        Returns the name of a provider.
+        """
+        return 'replicate'
+
+    def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
+        return []
+
+    def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
+        """
+        Returns the model class.
+
+        :param model_type:
+        :return:
+        """
+        if model_type == ModelType.TEXT_GENERATION:
+            model_class = ReplicateModel
+        elif model_type == ModelType.EMBEDDINGS:
+            model_class = ReplicateEmbedding
+        else:
+            raise NotImplementedError
+
+        return model_class
+
+    def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
+        """
+        get model parameter rules.
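+        Rules are built dynamically from the model version's OpenAPI input schema.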
+
+        :param model_name:
+        :param model_type:
+        :return:
+        """
+        model_credentials = self.get_model_credentials(model_name, model_type)
+
+        model = replicate.Client(api_token=model_credentials.get("replicate_api_token")).models.get(model_name)
+
+        try:
+            version = model.versions.get(model_credentials['model_version'])
+        except ReplicateError as e:
+            raise CredentialsValidateFailedError(f"Model {model_name}:{model_credentials['model_version']} does not exist, "
+                                                 f"cause: {e.__class__.__name__}:{str(e)}")
+        except Exception as e:
+            logging.exception("Model validate failed.")
+            raise e
+
+        model_kwargs_rules = ModelKwargsRules()
+        for key, value in version.openapi_schema['components']['schemas']['Input']['properties'].items():
+            if key not in ['debug', 'prompt'] and value['type'] in ['number', 'integer']:
+                if key in ['temperature', 'top_p']:
+                    kwarg_rule = KwargRule[float](
+                        type=KwargRuleType.FLOAT.value if value['type'] == 'number' else KwargRuleType.INTEGER.value,
+                        min=float(value.get('minimum')) if value.get('minimum') is not None else None,
+                        max=float(value.get('maximum')) if value.get('maximum') is not None else None,
+                        default=float(value.get('default')) if value.get('default') is not None else None,
+                    )
+                    if key == 'temperature':
+                        model_kwargs_rules.temperature = kwarg_rule
+                    else:
+                        model_kwargs_rules.top_p = kwarg_rule
+                elif key in ['max_length', 'max_new_tokens']:
+                    model_kwargs_rules.max_tokens = KwargRule[int](
+                        alias=key,
+                        type=KwargRuleType.INTEGER.value,
+                        min=int(value.get('minimum')) if value.get('minimum') is not None else 1,
+                        max=int(value.get('maximum')) if value.get('maximum') is not None else 8000,
+                        default=int(value.get('default')) if value.get('default') is not None else 500,
+                    )
+
+        return model_kwargs_rules
+
+    @classmethod
+    def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
+        """
+        check model credentials valid.
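+        Checks that the model version exists and that its OpenAPI schema matches the requested model type.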
+
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        """
+        if 'replicate_api_token' not in credentials:
+            raise CredentialsValidateFailedError('Replicate API Key must be provided.')
+
+        if 'model_version' not in credentials:
+            raise CredentialsValidateFailedError('Replicate Model Version must be provided.')
+
+        if model_name.count("/") != 1:
+            raise CredentialsValidateFailedError('Replicate model name must be in the format: '
+                                                 '{user_name}/{model_name}')
+
+        version = credentials['model_version']
+        try:
+            model = replicate.Client(api_token=credentials.get("replicate_api_token")).models.get(model_name)
+            rst = model.versions.get(version)
+
+            if model_type == ModelType.EMBEDDINGS \
+                    and 'Embedding' not in rst.openapi_schema['components']['schemas']:
+                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not an Embedding model.")
+            elif model_type == ModelType.TEXT_GENERATION \
+                    and ('type' not in rst.openapi_schema['components']['schemas']['Output']['items']
+                         or rst.openapi_schema['components']['schemas']['Output']['items']['type'] != 'string'):
+                raise CredentialsValidateFailedError(f"Model {model_name}:{version} is not a Text Generation model.")
+        except ReplicateError as e:
+            raise CredentialsValidateFailedError(
+                f"Model {model_name}:{version} does not exist, cause: {e.__class__.__name__}:{str(e)}")
+        except Exception as e:
+            logging.exception("Replicate config validation failed.")
+            raise e
+
+    @classmethod
+    def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
+                                  credentials: dict) -> dict:
+        """
+        encrypt model credentials for save.
+
+        :param tenant_id:
+        :param model_name:
+        :param model_type:
+        :param credentials:
+        :return:
+        """
+        credentials['replicate_api_token'] = encrypter.encrypt_token(tenant_id, credentials['replicate_api_token'])
+        return credentials
+
+    def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
+        """
+        get credentials for llm use.
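+        Replicate credentials are stored per model on ProviderModel.encrypted_config rather than per provider.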
+ + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + if self.provider.provider_type != ProviderType.CUSTOM.value: + raise NotImplementedError + + provider_model = self._get_provider_model(model_name, model_type) + + if not provider_model.encrypted_config: + return { + 'replicate_api_token': None, + } + + credentials = json.loads(provider_model.encrypted_config) + if credentials['replicate_api_token']: + credentials['replicate_api_token'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['replicate_api_token'] + ) + + if obfuscated: + credentials['replicate_api_token'] = encrypter.obfuscated_token(credentials['replicate_api_token']) + + return credentials + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + return + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + return {} + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + return {} diff --git a/api/core/model_providers/providers/spark_provider.py b/api/core/model_providers/providers/spark_provider.py new file mode 100644 index 0000000000..7bcd060be2 --- /dev/null +++ b/api/core/model_providers/providers/spark_provider.py @@ -0,0 +1,191 @@ +import json +import logging +from json import JSONDecodeError +from typing import Type + +from flask import current_app +from langchain.schema import HumanMessage + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.spark_model import SparkModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.third_party.langchain.llms.spark import ChatSpark +from core.third_party.spark.spark_llm import SparkError +from models.provider import ProviderType, ProviderQuotaType + + +class SparkProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'spark' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'spark', + 'name': '星火认知大模型', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = SparkModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + return ModelKwargsRules( + temperature=KwargRule[float](min=0, max=1, default=0.5), + top_p=KwargRule[float](enabled=False), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](min=10, max=4096, default=2048), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. 
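+        Requires 'app_id', 'api_key' and 'api_secret'; checked by sending a one-message "ping" chat.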
+ """ + if 'app_id' not in credentials: + raise CredentialsValidateFailedError('Spark app_id must be provided.') + + if 'api_key' not in credentials: + raise CredentialsValidateFailedError('Spark api_key must be provided.') + + if 'api_secret' not in credentials: + raise CredentialsValidateFailedError('Spark api_secret must be provided.') + + try: + credential_kwargs = { + 'app_id': credentials['app_id'], + 'api_key': credentials['api_key'], + 'api_secret': credentials['api_secret'], + } + + chat_llm = ChatSpark( + max_tokens=10, + temperature=0.01, + **credential_kwargs + ) + + messages = [ + HumanMessage( + content="ping" + ) + ] + + chat_llm(messages) + except SparkError as ex: + raise CredentialsValidateFailedError(str(ex)) + except Exception as ex: + logging.exception('Spark config validation failed') + raise ex + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key']) + credentials['api_secret'] = encrypter.encrypt_token(tenant_id, credentials['api_secret']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value \ + or (self.provider.provider_type == ProviderType.SYSTEM.value + and self.provider.quota_type == ProviderQuotaType.FREE.value): + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'app_id': None, + 'api_key': None, + 'api_secret': None, + } + + if credentials['api_key']: + credentials['api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['api_key'] + ) + + if obfuscated: + credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key']) + + if credentials['api_secret']: + credentials['api_secret'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['api_secret'] + ) + + if obfuscated: + credentials['api_secret'] = encrypter.obfuscated_token(credentials['api_secret']) + + return credentials + else: + return { + 'app_id': None, + 'api_key': None, + 'api_secret': None, + } + + def should_deduct_quota(self): + return True + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. 
+ + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/providers/tongyi_provider.py b/api/core/model_providers/providers/tongyi_provider.py new file mode 100644 index 0000000000..ffa7c72db4 --- /dev/null +++ b/api/core/model_providers/providers/tongyi_provider.py @@ -0,0 +1,157 @@ +import json +from json import JSONDecodeError +from typing import Type + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.tongyi_model import TongyiModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi +from models.provider import ProviderType + + +class TongyiProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'tongyi' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'qwen-v1', + 'name': 'qwen-v1', + }, + { + 'id': 'qwen-plus-v1', + 'name': 'qwen-plus-v1', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. + + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = TongyiModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + model_max_tokens = { + 'qwen-v1': 1500, + 'qwen-plus-v1': 6500 + } + + return ModelKwargsRules( + temperature=KwargRule[float](enabled=False), + top_p=KwargRule[float](min=0, max=1, default=0.8), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. 
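+        Requires 'dashscope_api_key'; checked with a minimal completion against qwen-v1.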
+ """ + if 'dashscope_api_key' not in credentials: + raise CredentialsValidateFailedError('Dashscope API Key must be provided.') + + try: + credential_kwargs = { + 'dashscope_api_key': credentials['dashscope_api_key'] + } + + llm = EnhanceTongyi( + model_name='qwen-v1', + max_retries=1, + **credential_kwargs + ) + + llm("ping") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['dashscope_api_key'] = encrypter.encrypt_token(tenant_id, credentials['dashscope_api_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'dashscope_api_key': None + } + + if credentials['dashscope_api_key']: + credentials['dashscope_api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['dashscope_api_key'] + ) + + if obfuscated: + credentials['dashscope_api_key'] = encrypter.obfuscated_token(credentials['dashscope_api_key']) + + return credentials + + return {} + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. + + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/providers/wenxin_provider.py b/api/core/model_providers/providers/wenxin_provider.py new file mode 100644 index 0000000000..1c62b72d95 --- /dev/null +++ b/api/core/model_providers/providers/wenxin_provider.py @@ -0,0 +1,182 @@ +import json +from json import JSONDecodeError +from typing import Type + +from core.helper import encrypter +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType +from core.model_providers.models.llm.wenxin_model import WenxinModel +from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError +from core.third_party.langchain.llms.wenxin import Wenxin +from models.provider import ProviderType + + +class WenxinProvider(BaseModelProvider): + + @property + def provider_name(self): + """ + Returns the name of a provider. + """ + return 'wenxin' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + if model_type == ModelType.TEXT_GENERATION: + return [ + { + 'id': 'ernie-bot', + 'name': 'ERNIE-Bot', + }, + { + 'id': 'ernie-bot-turbo', + 'name': 'ERNIE-Bot-turbo', + }, + { + 'id': 'bloomz-7b', + 'name': 'BLOOMZ-7B', + } + ] + else: + return [] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + """ + Returns the model class. 
+ + :param model_type: + :return: + """ + if model_type == ModelType.TEXT_GENERATION: + model_class = WenxinModel + else: + raise NotImplementedError + + return model_class + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + """ + get model parameter rules. + + :param model_name: + :param model_type: + :return: + """ + if model_name in ['ernie-bot', 'ernie-bot-turbo']: + return ModelKwargsRules( + temperature=KwargRule[float](min=0.01, max=1, default=0.95), + top_p=KwargRule[float](min=0.01, max=1, default=0.8), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](enabled=False), + ) + else: + return ModelKwargsRules( + temperature=KwargRule[float](enabled=False), + top_p=KwargRule[float](enabled=False), + presence_penalty=KwargRule[float](enabled=False), + frequency_penalty=KwargRule[float](enabled=False), + max_tokens=KwargRule[int](enabled=False), + ) + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + """ + Validates the given credentials. + """ + if 'api_key' not in credentials: + raise CredentialsValidateFailedError('Wenxin api_key must be provided.') + + if 'secret_key' not in credentials: + raise CredentialsValidateFailedError('Wenxin secret_key must be provided.') + + try: + credential_kwargs = { + 'api_key': credentials['api_key'], + 'secret_key': credentials['secret_key'], + } + + llm = Wenxin( + temperature=0.01, + **credential_kwargs + ) + + llm("ping") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key']) + credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key']) + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + if self.provider.provider_type == ProviderType.CUSTOM.value: + try: + credentials = json.loads(self.provider.encrypted_config) + except JSONDecodeError: + credentials = { + 'api_key': None, + 'secret_key': None, + } + + if credentials['api_key']: + credentials['api_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['api_key'] + ) + + if obfuscated: + credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key']) + + if credentials['secret_key']: + credentials['secret_key'] = encrypter.decrypt_token( + self.provider.tenant_id, + credentials['secret_key'] + ) + + if obfuscated: + credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key']) + + return credentials + else: + return { + 'api_key': None, + 'secret_key': None, + } + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + """ + check model credentials valid. + + :param model_name: + :param model_type: + :param credentials: + """ + return + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + """ + encrypt model credentials for save. + + :param tenant_id: + :param model_name: + :param model_type: + :param credentials: + :return: + """ + return {} + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + """ + get credentials for llm use. 
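+        Wenxin models share the provider-level 'api_key' / 'secret_key' pair.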
+ + :param model_name: + :param model_type: + :param obfuscated: + :return: + """ + return self.get_provider_credentials(obfuscated) diff --git a/api/core/model_providers/rules.py b/api/core/model_providers/rules.py new file mode 100644 index 0000000000..5a911500de --- /dev/null +++ b/api/core/model_providers/rules.py @@ -0,0 +1,47 @@ +import json +import os + + +def init_provider_rules(): + # Get the absolute path of the subdirectory + subdirectory_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'rules') + + # Path to the providers.json file + providers_json_file_path = os.path.join(subdirectory_path, '_providers.json') + + try: + # Open the JSON file and read its content + with open(providers_json_file_path, 'r') as json_file: + data = json.load(json_file) + # Store the content in a dictionary with the key as the file name (without extension) + provider_names = data + except FileNotFoundError: + return "JSON file not found or path error" + except json.JSONDecodeError: + return "JSON file decoding error" + + # Dictionary to store the content of all JSON files + json_data = {} + + try: + # Loop through all files in the directory + for provider_name in provider_names: + filename = provider_name + '.json' + + # Path to each JSON file + json_file_path = os.path.join(subdirectory_path, filename) + + # Open each JSON file and read its content + with open(json_file_path, 'r') as json_file: + data = json.load(json_file) + # Store the content in the dictionary with the key as the file name (without extension) + json_data[os.path.splitext(filename)[0]] = data + + return json_data + except FileNotFoundError: + return "JSON file not found or path error" + except json.JSONDecodeError: + return "JSON file decoding error" + + +provider_rules = init_provider_rules() diff --git a/api/core/model_providers/rules/_providers.json b/api/core/model_providers/rules/_providers.json new file mode 100644 index 0000000000..ad53f425ce --- /dev/null +++ b/api/core/model_providers/rules/_providers.json @@ -0,0 +1,12 @@ +[ + "openai", + "azure_openai", + "anthropic", + "minimax", + "tongyi", + "spark", + "wenxin", + "chatglm", + "replicate", + "huggingface_hub" +] \ No newline at end of file diff --git a/api/core/model_providers/rules/anthropic.json b/api/core/model_providers/rules/anthropic.json new file mode 100644 index 0000000000..56806aa7c6 --- /dev/null +++ b/api/core/model_providers/rules/anthropic.json @@ -0,0 +1,15 @@ +{ + "support_provider_types": [ + "system", + "custom" + ], + "system_config": { + "supported_quota_types": [ + "trial", + "paid" + ], + "quota_unit": "times", + "quota_limit": 1000 + }, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/azure_openai.json b/api/core/model_providers/rules/azure_openai.json new file mode 100644 index 0000000000..5badb07178 --- /dev/null +++ b/api/core/model_providers/rules/azure_openai.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "configurable" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/chatglm.json b/api/core/model_providers/rules/chatglm.json new file mode 100644 index 0000000000..0af3e61ec7 --- /dev/null +++ b/api/core/model_providers/rules/chatglm.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/huggingface_hub.json 
b/api/core/model_providers/rules/huggingface_hub.json new file mode 100644 index 0000000000..5badb07178 --- /dev/null +++ b/api/core/model_providers/rules/huggingface_hub.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "configurable" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/minimax.json b/api/core/model_providers/rules/minimax.json new file mode 100644 index 0000000000..e19b885a25 --- /dev/null +++ b/api/core/model_providers/rules/minimax.json @@ -0,0 +1,13 @@ +{ + "support_provider_types": [ + "system", + "custom" + ], + "system_config": { + "supported_quota_types": [ + "free" + ], + "quota_unit": "tokens" + }, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/openai.json b/api/core/model_providers/rules/openai.json new file mode 100644 index 0000000000..e615de6063 --- /dev/null +++ b/api/core/model_providers/rules/openai.json @@ -0,0 +1,14 @@ +{ + "support_provider_types": [ + "system", + "custom" + ], + "system_config": { + "supported_quota_types": [ + "trial" + ], + "quota_unit": "times", + "quota_limit": 200 + }, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/replicate.json b/api/core/model_providers/rules/replicate.json new file mode 100644 index 0000000000..5badb07178 --- /dev/null +++ b/api/core/model_providers/rules/replicate.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "configurable" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/spark.json b/api/core/model_providers/rules/spark.json new file mode 100644 index 0000000000..e19b885a25 --- /dev/null +++ b/api/core/model_providers/rules/spark.json @@ -0,0 +1,13 @@ +{ + "support_provider_types": [ + "system", + "custom" + ], + "system_config": { + "supported_quota_types": [ + "free" + ], + "quota_unit": "tokens" + }, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/tongyi.json b/api/core/model_providers/rules/tongyi.json new file mode 100644 index 0000000000..0af3e61ec7 --- /dev/null +++ b/api/core/model_providers/rules/tongyi.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/model_providers/rules/wenxin.json b/api/core/model_providers/rules/wenxin.json new file mode 100644 index 0000000000..0af3e61ec7 --- /dev/null +++ b/api/core/model_providers/rules/wenxin.json @@ -0,0 +1,7 @@ +{ + "support_provider_types": [ + "custom" + ], + "system_config": null, + "model_flexibility": "fixed" +} \ No newline at end of file diff --git a/api/core/orchestrator_rule_parser.py b/api/core/orchestrator_rule_parser.py index 38361f65c5..021f8c935f 100644 --- a/api/core/orchestrator_rule_parser.py +++ b/api/core/orchestrator_rule_parser.py @@ -3,7 +3,6 @@ from typing import Optional from langchain import WikipediaAPIWrapper from langchain.callbacks.manager import Callbacks -from langchain.chat_models import ChatOpenAI from langchain.memory.chat_memory import BaseChatMemory from langchain.tools import BaseTool, Tool, WikipediaQueryRun from pydantic import BaseModel, Field @@ -15,7 +14,8 @@ from core.callback_handler.main_chain_gather_callback_handler import MainChainGa from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler from 
core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain from core.conversation_message_task import ConversationMessageTask -from core.llm.llm_builder import LLMBuilder +from core.model_providers.model_factory import ModelFactory +from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode from core.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tool.provider.serpapi_provider import SerpAPIToolProvider from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput @@ -32,11 +32,9 @@ class OrchestratorRuleParser: def __init__(self, tenant_id: str, app_model_config: AppModelConfig): self.tenant_id = tenant_id self.app_model_config = app_model_config - self.agent_summary_model_name = "gpt-3.5-turbo-16k" - self.dataset_retrieve_model_name = "gpt-3.5-turbo" def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory], - rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ + rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \ -> Optional[AgentExecutor]: if not self.app_model_config.agent_mode_dict: return None @@ -47,43 +45,50 @@ class OrchestratorRuleParser: chain = None if agent_mode_config and agent_mode_config.get('enabled'): tool_configs = agent_mode_config.get('tools', []) + agent_provider_name = model_dict.get('provider', 'openai') agent_model_name = model_dict.get('name', 'gpt-4') + agent_model_instance = ModelFactory.get_text_generation_model( + tenant_id=self.tenant_id, + model_provider_name=agent_provider_name, + model_name=agent_model_name, + model_kwargs=ModelKwargs( + temperature=0.2, + top_p=0.3, + max_tokens=1500 + ) + ) + # add agent callback to record agent thoughts agent_callback = AgentLoopGatherCallbackHandler( - model_name=agent_model_name, + model_instant=agent_model_instance, conversation_message_task=conversation_message_task ) chain_callback.agent_callback = agent_callback - - agent_llm = LLMBuilder.to_llm( - tenant_id=self.tenant_id, - model_name=agent_model_name, - temperature=0, - max_tokens=1500, - callbacks=[agent_callback, DifyStdOutCallbackHandler()] - ) + agent_model_instance.add_callbacks([agent_callback]) planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router')) # only OpenAI chat model (include Azure) support function call, use ReACT instead - if not isinstance(agent_llm, ChatOpenAI) \ - and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: - planning_strategy = PlanningStrategy.REACT + if agent_model_instance.model_mode != ModelMode.CHAT \ + or agent_model_instance.name not in ['openai', 'azure_openai']: + if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]: + planning_strategy = PlanningStrategy.REACT + elif planning_strategy == PlanningStrategy.ROUTER: + planning_strategy = PlanningStrategy.REACT_ROUTER - summary_llm = LLMBuilder.to_llm( + summary_model_instance = ModelFactory.get_text_generation_model( tenant_id=self.tenant_id, - model_name=self.agent_summary_model_name, - temperature=0, - max_tokens=500, - callbacks=[DifyStdOutCallbackHandler()] + model_kwargs=ModelKwargs( + temperature=0, + max_tokens=500 + ) ) tools = self.to_tools( tool_configs=tool_configs, conversation_message_task=conversation_message_task, - model_name=self.agent_summary_model_name, rest_tokens=rest_tokens, callbacks=[agent_callback, DifyStdOutCallbackHandler()] ) @@ -91,20 +96,11 @@ class 
OrchestratorRuleParser: if len(tools) == 0: return None - dataset_llm = LLMBuilder.to_llm( - tenant_id=self.tenant_id, - model_name=self.dataset_retrieve_model_name, - temperature=0, - max_tokens=500, - callbacks=[DifyStdOutCallbackHandler()] - ) - agent_configuration = AgentConfiguration( strategy=planning_strategy, - llm=agent_llm, + model_instance=agent_model_instance, tools=tools, - summary_llm=summary_llm, - dataset_llm=dataset_llm, + summary_model_instance=summary_model_instance, memory=memory, callbacks=[chain_callback, agent_callback], max_iterations=10, @@ -141,13 +137,12 @@ class OrchestratorRuleParser: return None def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask, - model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]: + rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]: """ Convert app agent tool configs to tools :param rest_tokens: :param tool_configs: app agent tool configs - :param model_name: :param conversation_message_task: :param callbacks: :return: @@ -163,7 +158,7 @@ class OrchestratorRuleParser: if tool_type == "dataset": tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens) elif tool_type == "web_reader": - tool = self.to_web_reader_tool(model_name) + tool = self.to_web_reader_tool() elif tool_type == "google_search": tool = self.to_google_search_tool() elif tool_type == "wikipedia": @@ -205,20 +200,22 @@ class OrchestratorRuleParser: return tool - def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]: + def to_web_reader_tool(self) -> Optional[BaseTool]: """ A tool for reading web pages :return: """ - summary_llm = LLMBuilder.to_llm( + summary_model_instance = ModelFactory.get_text_generation_model( tenant_id=self.tenant_id, - model_name=model_name, - temperature=0, - max_tokens=500, - callbacks=[DifyStdOutCallbackHandler()] + model_kwargs=ModelKwargs( + temperature=0, + max_tokens=500 + ) ) + summary_llm = summary_model_instance.client + tool = WebReaderTool( llm=summary_llm, max_chunk_length=4000, @@ -273,6 +270,10 @@ class OrchestratorRuleParser: def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int: DEFAULT_K = 2 CONTEXT_TOKENS_PERCENT = 0.3 + + if rest_tokens == -1: + return DEFAULT_K + processing_rule = dataset.latest_process_rule if not processing_rule: return DEFAULT_K diff --git a/api/core/third_party/langchain/embeddings/__init__.py b/api/core/third_party/langchain/embeddings/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/third_party/langchain/embeddings/replicate_embedding.py b/api/core/third_party/langchain/embeddings/replicate_embedding.py new file mode 100644 index 0000000000..0113ba8db7 --- /dev/null +++ b/api/core/third_party/langchain/embeddings/replicate_embedding.py @@ -0,0 +1,99 @@ +"""Wrapper around Replicate embedding models.""" +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Extra, root_validator + +from langchain.embeddings.base import Embeddings +from langchain.utils import get_from_dict_or_env + + +class ReplicateEmbeddings(BaseModel, Embeddings): + """Wrapper around Replicate embedding models. + + To use, you should have the ``replicate`` python package installed. 
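+    The ``model`` must include a version ("owner/model-name:version"); the version's OpenAPI schema determines the text input field.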
+ """ + + client: Any #: :meta private: + model: str + """Model name to use.""" + + replicate_api_token: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + replicate_api_token = get_from_dict_or_env( + values, "replicate_api_token", "REPLICATE_API_TOKEN" + ) + try: + import replicate as replicate_python + + values["client"] = replicate_python.Client(api_token=replicate_api_token) + except ImportError: + raise ImportError( + "Could not import replicate python package. " + "Please install it with `pip install replicate`." + ) + return values + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to Replicate's embedding endpoint. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + # get the model and version + model_str, version_str = self.model.split(":") + model = self.client.models.get(model_str) + version = model.versions.get(version_str) + + # sort through the openapi schema to get the name of the first input + input_properties = sorted( + version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ].items(), + key=lambda item: item[1].get("x-order", 0), + ) + first_input_name = input_properties[0][0] + + embeddings = [] + for text in texts: + result = self.client.run(self.model, input={first_input_name: text}) + embeddings.append(result[0].get('embedding')) + + return [list(map(float, e)) for e in embeddings] + + def embed_query(self, text: str) -> List[float]: + """Call out to Replicate's embedding endpoint. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. 
+ """ + # get the model and version + model_str, version_str = self.model.split(":") + model = self.client.models.get(model_str) + version = model.versions.get(version_str) + + # sort through the openapi schema to get the name of the first input + input_properties = sorted( + version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ].items(), + key=lambda item: item[1].get("x-order", 0), + ) + first_input_name = input_properties[0][0] + result = self.client.run(self.model, input={first_input_name: text}) + embedding = result[0].get('embedding') + + return list(map(float, embedding)) diff --git a/api/core/third_party/langchain/llms/__init__.py b/api/core/third_party/langchain/llms/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/llm/streamable_azure_chat_open_ai.py b/api/core/third_party/langchain/llms/azure_chat_open_ai.py similarity index 75% rename from api/core/llm/streamable_azure_chat_open_ai.py rename to api/core/third_party/langchain/llms/azure_chat_open_ai.py index 1e2681fa8d..a1f6aa6ece 100644 --- a/api/core/llm/streamable_azure_chat_open_ai.py +++ b/api/core/third_party/langchain/llms/azure_chat_open_ai.py @@ -1,15 +1,13 @@ -from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun -from langchain.chat_models.openai import _convert_dict_to_message -from langchain.schema import BaseMessage, LLMResult, ChatResult, ChatGeneration -from langchain.chat_models import AzureChatOpenAI -from typing import Optional, List, Dict, Any, Tuple, Union +from typing import Dict, Any, Optional, List, Tuple, Union +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.chat_models import AzureChatOpenAI +from langchain.chat_models.openai import _convert_dict_to_message +from langchain.schema import ChatResult, BaseMessage, ChatGeneration from pydantic import root_validator -from core.llm.wrappers.openai_wrapper import handle_openai_exceptions - -class StreamableAzureChatOpenAI(AzureChatOpenAI): +class EnhanceAzureChatOpenAI(AzureChatOpenAI): request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0) """Timeout for requests to OpenAI completion API. 
Default is 600 seconds.""" max_retries: int = 1 @@ -52,32 +50,6 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): "organization": self.openai_organization if self.openai_organization else None, } - @handle_openai_exceptions - def generate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return super().generate(messages, stop, callbacks, **kwargs) - - @classmethod - def get_kwargs_from_model_params(cls, params: dict): - model_kwargs = { - 'top_p': params.get('top_p', 1), - 'frequency_penalty': params.get('frequency_penalty', 0), - 'presence_penalty': params.get('presence_penalty', 0), - } - - del params['top_p'] - del params['frequency_penalty'] - del params['presence_penalty'] - - params['model_kwargs'] = model_kwargs - - return params - def _generate( self, messages: List[BaseMessage], @@ -116,4 +88,4 @@ class StreamableAzureChatOpenAI(AzureChatOpenAI): ) return ChatResult(generations=[ChatGeneration(message=message)]) response = self.completion_with_retry(messages=message_dicts, **params) - return self._create_chat_result(response) + return self._create_chat_result(response) \ No newline at end of file diff --git a/api/core/llm/streamable_azure_open_ai.py b/api/core/third_party/langchain/llms/azure_open_ai.py similarity index 87% rename from api/core/llm/streamable_azure_open_ai.py rename to api/core/third_party/langchain/llms/azure_open_ai.py index ab67f5abcb..f4d7155336 100644 --- a/api/core/llm/streamable_azure_open_ai.py +++ b/api/core/third_party/langchain/llms/azure_open_ai.py @@ -1,16 +1,14 @@ -from langchain.callbacks.manager import Callbacks, CallbackManagerForLLMRun +from typing import Dict, Any, Mapping, Optional, List, Union, Tuple + +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms import AzureOpenAI from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \ update_token_usage from langchain.schema import LLMResult -from typing import Optional, List, Dict, Mapping, Any, Union, Tuple - from pydantic import root_validator -from core.llm.wrappers.openai_wrapper import handle_openai_exceptions - -class StreamableAzureOpenAI(AzureOpenAI): +class EnhanceAzureOpenAI(AzureOpenAI): openai_api_type: str = "azure" openai_api_version: str = "" request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0) @@ -56,20 +54,6 @@ class StreamableAzureOpenAI(AzureOpenAI): "organization": self.openai_organization if self.openai_organization else None, }} - @handle_openai_exceptions - def generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return super().generate(prompts, stop, callbacks, **kwargs) - - @classmethod - def get_kwargs_from_model_params(cls, params: dict): - return params - def _generate( self, prompts: List[str], diff --git a/api/core/llm/streamable_chat_open_ai.py b/api/core/third_party/langchain/llms/chat_open_ai.py similarity index 62% rename from api/core/llm/streamable_chat_open_ai.py rename to api/core/third_party/langchain/llms/chat_open_ai.py index 64e4548442..b409a9889b 100644 --- a/api/core/llm/streamable_chat_open_ai.py +++ b/api/core/third_party/langchain/llms/chat_open_ai.py @@ -1,16 +1,12 @@ import os -from langchain.callbacks.manager import Callbacks -from langchain.schema import BaseMessage, LLMResult -from langchain.chat_models import ChatOpenAI -from typing import Optional, 
List, Dict, Any, Union, Tuple +from typing import Dict, Any, Optional, Union, Tuple +from langchain.chat_models import ChatOpenAI from pydantic import root_validator -from core.llm.wrappers.openai_wrapper import handle_openai_exceptions - -class StreamableChatOpenAI(ChatOpenAI): +class EnhanceChatOpenAI(ChatOpenAI): request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0) """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" max_retries: int = 1 @@ -51,29 +47,3 @@ class StreamableChatOpenAI(ChatOpenAI): "api_key": self.openai_api_key, "organization": self.openai_organization if self.openai_organization else None, } - - @handle_openai_exceptions - def generate( - self, - messages: List[List[BaseMessage]], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return super().generate(messages, stop, callbacks, **kwargs) - - @classmethod - def get_kwargs_from_model_params(cls, params: dict): - model_kwargs = { - 'top_p': params.get('top_p', 1), - 'frequency_penalty': params.get('frequency_penalty', 0), - 'presence_penalty': params.get('presence_penalty', 0), - } - - del params['top_p'] - del params['frequency_penalty'] - del params['presence_penalty'] - - params['model_kwargs'] = model_kwargs - - return params diff --git a/api/core/llm/fake.py b/api/core/third_party/langchain/llms/fake.py similarity index 85% rename from api/core/llm/fake.py rename to api/core/third_party/langchain/llms/fake.py index b7190220f8..b901df935a 100644 --- a/api/core/llm/fake.py +++ b/api/core/third_party/langchain/llms/fake.py @@ -1,9 +1,11 @@ import time -from typing import List, Optional, Any, Mapping +from typing import List, Optional, Any, Mapping, Callable from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel -from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration, BaseLanguageModel +from langchain.schema import BaseMessage, ChatResult, AIMessage, ChatGeneration + +from core.model_providers.models.entity.message import str_to_prompt_messages class FakeLLM(SimpleChatModel): @@ -12,7 +14,7 @@ class FakeLLM(SimpleChatModel): streaming: bool = False """Whether to stream the results or not.""" response: str - origin_llm: Optional[BaseLanguageModel] = None + num_token_func: Optional[Callable] = None @property def _llm_type(self) -> str: @@ -33,7 +35,7 @@ class FakeLLM(SimpleChatModel): return {"response": self.response} def get_num_tokens(self, text: str) -> int: - return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0 + return self.num_token_func(str_to_prompt_messages([text])) if self.num_token_func else 0 def _generate( self, diff --git a/api/core/llm/streamable_open_ai.py b/api/core/third_party/langchain/llms/open_ai.py similarity index 74% rename from api/core/llm/streamable_open_ai.py rename to api/core/third_party/langchain/llms/open_ai.py index cfb32da3a6..a16998ab5f 100644 --- a/api/core/llm/streamable_open_ai.py +++ b/api/core/third_party/langchain/llms/open_ai.py @@ -1,15 +1,11 @@ import os -from langchain.callbacks.manager import Callbacks -from langchain.schema import LLMResult -from typing import Optional, List, Dict, Any, Mapping, Union, Tuple +from typing import Dict, Any, Mapping, Optional, Union, Tuple from langchain import OpenAI from pydantic import root_validator -from core.llm.wrappers.openai_wrapper import handle_openai_exceptions - -class StreamableOpenAI(OpenAI): +class EnhanceOpenAI(OpenAI): 
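+    # As with the other Enhance* wrappers: request_timeout is a requests-style
+    # (connect, read) tuple and max_retries is pinned to 1, so the inherited
+    # "Default is 600 seconds" docstring predates the (5.0, 300.0) default here.
+    # Exception mapping (the removed @handle_openai_exceptions) is assumed to
+    # move into the calling model_providers layer.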
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0) """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" max_retries: int = 1 @@ -52,17 +48,3 @@ class StreamableOpenAI(OpenAI): "api_key": self.openai_api_key, "organization": self.openai_organization if self.openai_organization else None, }} - - @handle_openai_exceptions - def generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - callbacks: Callbacks = None, - **kwargs: Any, - ) -> LLMResult: - return super().generate(prompts, stop, callbacks, **kwargs) - - @classmethod - def get_kwargs_from_model_params(cls, params: dict): - return params diff --git a/api/core/third_party/langchain/llms/replicate_llm.py b/api/core/third_party/langchain/llms/replicate_llm.py new file mode 100644 index 0000000000..556ef2b102 --- /dev/null +++ b/api/core/third_party/langchain/llms/replicate_llm.py @@ -0,0 +1,75 @@ +from typing import Dict, Optional, List, Any + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms import Replicate +from langchain.utils import get_from_dict_or_env +from pydantic import root_validator + + +class EnhanceReplicate(Replicate): + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + replicate_api_token = get_from_dict_or_env( + values, "replicate_api_token", "REPLICATE_API_TOKEN" + ) + values["replicate_api_token"] = replicate_api_token + return values + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Call to replicate endpoint.""" + try: + import replicate as replicate_python + except ImportError: + raise ImportError( + "Could not import replicate python package. " + "Please install it with `pip install replicate`." 
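+                # note: validate_environment above only stores the token; the
+                # replicate import and client construction are deferred to this
+                # call, so a missing package or bad token surfaces at call time
+                # rather than when the model is configured.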
+ ) + + client = replicate_python.Client(api_token=self.replicate_api_token) + + # get the model and version + model_str, version_str = self.model.split(":") + model = client.models.get(model_str) + version = model.versions.get(version_str) + + # sort through the openapi schema to get the name of the first input + input_properties = sorted( + version.openapi_schema["components"]["schemas"]["Input"][ + "properties" + ].items(), + key=lambda item: item[1].get("x-order", 0), + ) + first_input_name = input_properties[0][0] + inputs = {first_input_name: prompt, **self.input} + + prediction = client.predictions.create( + version=version, input={**inputs, **kwargs} + ) + current_completion: str = "" + stop_condition_reached = False + for output in prediction.output_iterator(): + current_completion += output + + # test for stop conditions, if specified + if stop: + for s in stop: + if s in current_completion: + prediction.cancel() + stop_index = current_completion.find(s) + current_completion = current_completion[:stop_index] + stop_condition_reached = True + break + + if stop_condition_reached: + break + + if self.streaming and run_manager: + run_manager.on_llm_new_token(output) + return current_completion diff --git a/api/core/third_party/langchain/llms/spark.py b/api/core/third_party/langchain/llms/spark.py new file mode 100644 index 0000000000..23eb7472a9 --- /dev/null +++ b/api/core/third_party/langchain/llms/spark.py @@ -0,0 +1,185 @@ +import re +import string +import threading +from _decimal import Decimal, ROUND_HALF_UP +from typing import Dict, List, Optional, Any, Mapping + +from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun +from langchain.chat_models.base import BaseChatModel +from langchain.llms.utils import enforce_stop_tokens +from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \ + ChatGeneration +from langchain.utils import get_from_dict_or_env +from pydantic import root_validator + +from core.third_party.spark.spark_llm import SparkLLMClient + + +class ChatSpark(BaseChatModel): + r"""Wrapper around Spark's large language model. + + To use, you should pass `app_id`, `api_key`, `api_secret` + as a named parameter to the constructor. + + Example: + .. 
code-block:: python + + client = SparkLLMClient( + app_id="", + api_key="", + api_secret="" + ) + """ + client: Any = None #: :meta private: + + max_tokens: int = 256 + """Denotes the number of tokens to predict per generation.""" + + temperature: Optional[float] = None + """A non-negative float that tunes the degree of randomness in generation.""" + + top_k: Optional[int] = None + """Number of most likely tokens to consider at each step.""" + + user_id: Optional[str] = None + """User ID to use for the model.""" + + streaming: bool = False + """Whether to stream the results.""" + + app_id: Optional[str] = None + api_key: Optional[str] = None + api_secret: Optional[str] = None + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["app_id"] = get_from_dict_or_env( + values, "app_id", "SPARK_APP_ID" + ) + values["api_key"] = get_from_dict_or_env( + values, "api_key", "SPARK_API_KEY" + ) + values["api_secret"] = get_from_dict_or_env( + values, "api_secret", "SPARK_API_SECRET" + ) + + values["client"] = SparkLLMClient( + app_id=values["app_id"], + api_key=values["api_key"], + api_secret=values["api_secret"], + ) + return values + + @property + def _default_params(self) -> Mapping[str, Any]: + """Get the default parameters for calling Anthropic API.""" + d = { + "max_tokens": self.max_tokens + } + if self.temperature is not None: + d["temperature"] = self.temperature + if self.top_k is not None: + d["top_k"] = self.top_k + return d + + @property + def _identifying_params(self) -> Mapping[str, Any]: + """Get the identifying parameters.""" + return {**{}, **self._default_params} + @property + def lc_secrets(self) -> Dict[str, str]: + return {"api_key": "API_KEY", "api_secret": "API_SECRET"} + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "spark-chat" + + @property + def lc_serializable(self) -> bool: + return True + + def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]: + """Format a list of messages into a full dict list. + + Args: + messages (List[BaseMessage]): List of BaseMessage to combine. 
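+
+            Example (illustrative):
+                [SystemMessage(content="Be brief."), HumanMessage(content="Hi"),
+                 AIMessage(content="Hello!")]
+                becomes
+                [{'role': 'user', 'content': 'Be brief.'},
+                 {'role': 'user', 'content': 'Hi'},
+                 {'role': 'assistant', 'content': 'Hello!'}]
+                (Spark defines no system role, so the mapping below sends
+                system turns as 'user'.)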
+ + Returns: + list[dict] + """ + messages = messages.copy() # don't mutate the original list + + new_messages = [] + for message in messages: + if isinstance(message, ChatMessage): + new_messages.append({'role': 'user', 'content': message.content}) + elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage): + new_messages.append({'role': 'user', 'content': message.content}) + elif isinstance(message, AIMessage): + new_messages.append({'role': 'assistant', 'content': message.content}) + else: + raise ValueError(f"Got unknown type {message}") + + return new_messages + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + messages = self._convert_messages_to_dicts(messages) + + thread = threading.Thread(target=self.client.run, args=( + messages, + self.user_id, + self._default_params, + self.streaming + )) + thread.start() + + completion = "" + for content in self.client.subscribe(): + if isinstance(content, dict): + delta = content['data'] + else: + delta = content + + completion += delta + if self.streaming and run_manager: + run_manager.on_llm_new_token( + delta, + ) + + thread.join() + + if stop is not None: + completion = enforce_stop_tokens(completion, stop) + + message = AIMessage(content=completion) + return ChatResult(generations=[ChatGeneration(message=message)]) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + message = AIMessage(content='') + return ChatResult(generations=[ChatGeneration(message=message)]) + + def get_num_tokens(self, text: str) -> float: + """Calculate number of tokens.""" + total = Decimal(0) + words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text) + for word in words: + if word: + if '\u4e00' <= word <= '\u9fff': # if chinese + total += Decimal('1.5') + else: + total += Decimal('0.8') + return int(total) diff --git a/api/core/third_party/langchain/llms/tongyi_llm.py b/api/core/third_party/langchain/llms/tongyi_llm.py new file mode 100644 index 0000000000..c8241fe084 --- /dev/null +++ b/api/core/third_party/langchain/llms/tongyi_llm.py @@ -0,0 +1,82 @@ +from typing import Dict, Any, List, Optional + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms import Tongyi +from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry +from langchain.schema import Generation, LLMResult + + +class EnhanceTongyi(Tongyi): + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + normal_params = { + "top_p": self.top_p, + "api_key": self.dashscope_api_key + } + + return {**normal_params, **self.model_kwargs} + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> LLMResult: + generations = [] + params: Dict[str, Any] = { + **{"model": self.model_name}, + **self._default_params, + **kwargs, + } + if self.streaming: + if len(prompts) > 1: + raise ValueError("Cannot stream results with multiple prompts.") + params["stream"] = True + text = '' + for stream_resp in stream_generate_with_retry( + self, prompt=prompts[0], **params + ): + if not generations: + current_text = stream_resp["output"]["text"] + else: + current_text = 
stream_resp["output"]["text"][len(text):] + + text = stream_resp["output"]["text"] + + generations.append( + [ + Generation( + text=current_text, + generation_info=dict( + finish_reason=stream_resp["output"]["finish_reason"], + ), + ) + ] + ) + + if run_manager: + run_manager.on_llm_new_token( + current_text, + verbose=self.verbose, + logprobs=None, + ) + else: + for prompt in prompts: + completion = generate_with_retry( + self, + prompt=prompt, + **params, + ) + generations.append( + [ + Generation( + text=completion["output"]["text"], + generation_info=dict( + finish_reason=completion["output"]["finish_reason"], + ), + ) + ] + ) + return LLMResult(generations=generations) diff --git a/api/core/third_party/langchain/llms/wenxin.py b/api/core/third_party/langchain/llms/wenxin.py new file mode 100644 index 0000000000..a10fb82b71 --- /dev/null +++ b/api/core/third_party/langchain/llms/wenxin.py @@ -0,0 +1,233 @@ +"""Wrapper around Wenxin APIs.""" +from __future__ import annotations + +import json +import logging +from typing import ( + Any, + Dict, + List, + Optional, Iterator, +) + +import requests +from langchain.llms.utils import enforce_stop_tokens +from langchain.schema.output import GenerationChunk +from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator + +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) +from langchain.llms.base import LLM +from langchain.utils import get_from_dict_or_env + +logger = logging.getLogger(__name__) + + +class _WenxinEndpointClient(BaseModel): + """An API client that talks to a Wenxin llm endpoint.""" + + base_url: str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/" + secret_key: str + api_key: str + + def get_access_token(self) -> str: + url = f"https://aip.baidubce.com/oauth/2.0/token?client_id={self.api_key}" \ + f"&client_secret={self.secret_key}&grant_type=client_credentials" + + headers = { + 'Content-Type': 'application/json', + 'Accept': 'application/json' + } + + response = requests.post(url, headers=headers) + if not response.ok: + raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}") + if 'error' in response.json(): + raise ValueError( + f"Wenxin API {response.json()['error']}" + f" error: {response.json()['error_description']}" + ) + + access_token = response.json()['access_token'] + + # todo add cache + + return access_token + + def post(self, request: dict) -> Any: + if 'model' not in request: + raise ValueError(f"Wenxin Model name is required") + + model_url_map = { + 'ernie-bot': 'completions', + 'ernie-bot-turbo': 'eb-instant', + 'bloomz-7b': 'bloomz_7b1', + } + + stream = 'stream' in request and request['stream'] + + access_token = self.get_access_token() + api_url = f"{self.base_url}{model_url_map[request['model']]}?access_token={access_token}" + + headers = {"Content-Type": "application/json"} + response = requests.post(api_url, + headers=headers, + json=request, + stream=stream) + if not response.ok: + raise ValueError(f"Wenxin HTTP {response.status_code} error: {response.text}") + + if not stream: + json_response = response.json() + if 'error_code' in json_response: + raise ValueError( + f"Wenxin API {json_response['error_code']}" + f" error: {json_response['error_msg']}" + ) + return json_response["result"] + else: + return response + + +class Wenxin(LLM): + """Wrapper around Wenxin large language models. 
+ To use, you should have the environment variable + ``WENXIN_API_KEY`` and ``WENXIN_SECRET_KEY`` set with your API key, + or pass them as a named parameter to the constructor. + Example: + .. code-block:: python + from langchain.llms.wenxin import Wenxin + wenxin = Wenxin(model="", api_key="my-api-key", + secret_key="my-group-id") + """ + + _client: _WenxinEndpointClient = PrivateAttr() + model: str = "ernie-bot" + """Model name to use.""" + temperature: float = 0.7 + """A non-negative float that tunes the degree of randomness in generation.""" + top_p: float = 0.95 + """Total probability mass of tokens to consider at each step.""" + model_kwargs: Dict[str, Any] = Field(default_factory=dict) + """Holds any model parameters valid for `create` call not explicitly specified.""" + streaming: bool = False + """Whether to stream the response or return it all at once.""" + api_key: Optional[str] = None + secret_key: Optional[str] = None + + class Config: + """Configuration for this pydantic object.""" + + extra = Extra.forbid + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + values["api_key"] = get_from_dict_or_env( + values, "api_key", "WENXIN_API_KEY" + ) + values["secret_key"] = get_from_dict_or_env( + values, "secret_key", "WENXIN_SECRET_KEY" + ) + return values + + @property + def _default_params(self) -> Dict[str, Any]: + """Get the default parameters for calling OpenAI API.""" + return { + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "stream": self.streaming, + **self.model_kwargs, + } + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return {**{"model": self.model}, **self._default_params} + + @property + def _llm_type(self) -> str: + """Return type of llm.""" + return "wenxin" + + def __init__(self, **data: Any): + super().__init__(**data) + self._client = _WenxinEndpointClient( + api_key=self.api_key, + secret_key=self.secret_key, + ) + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + r"""Call out to Wenxin's completion endpoint to chat + Args: + prompt: The prompt to pass into the model. + Returns: + The string generated by the model. + Example: + .. code-block:: python + response = wenxin("Tell me a joke.") + """ + if self.streaming: + completion = "" + for chunk in self._stream( + prompt=prompt, stop=stop, run_manager=run_manager, **kwargs + ): + completion += chunk.text + else: + request = self._default_params + request["messages"] = [{"role": "user", "content": prompt}] + request.update(kwargs) + completion = self._client.post(request) + + if stop is not None: + completion = enforce_stop_tokens(completion, stop) + + return completion + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + r"""Call wenxin completion_stream and return the resulting generator. + + Args: + prompt: The prompt to pass into the model. + stop: Optional list of stop words to use when generating. + Returns: + A generator representing the stream of tokens from Wenxin. + Example: + .. code-block:: python + + prompt = "Write a poem about a stream." 
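+                # each yielded token is parsed from one "data: {...}" SSE line
+                # (see the iter_lines handling in the method body below)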
+ prompt = f"\n\nHuman: {prompt}\n\nAssistant:" + generator = wenxin.stream(prompt) + for token in generator: + yield token + """ + request = self._default_params + request["messages"] = [{"role": "user", "content": prompt}] + request.update(kwargs) + + for token in self._client.post(request).iter_lines(): + if token: + token = token.decode("utf-8") + completion = json.loads(token[5:]) + + yield GenerationChunk(text=completion['result']) + if run_manager: + run_manager.on_llm_new_token(completion['result']) + + if completion['is_end']: + break diff --git a/api/core/third_party/spark/__init__.py b/api/core/third_party/spark/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py new file mode 100644 index 0000000000..2b6d9b498c --- /dev/null +++ b/api/core/third_party/spark/spark_llm.py @@ -0,0 +1,150 @@ +import base64 +import datetime +import hashlib +import hmac +import json +import queue +from typing import Optional +from urllib.parse import urlparse +import ssl +from datetime import datetime +from time import mktime +from urllib.parse import urlencode +from wsgiref.handlers import format_date_time + +import websocket + + +class SparkLLMClient: + def __init__(self, app_id: str, api_key: str, api_secret: str): + + self.api_base = "ws://spark-api.xf-yun.com/v1.1/chat" + self.app_id = app_id + self.ws_url = self.create_url( + urlparse(self.api_base).netloc, + urlparse(self.api_base).path, + self.api_base, + api_key, + api_secret + ) + + self.queue = queue.Queue() + self.blocking_message = '' + + def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str: + # generate timestamp by RFC1123 + now = datetime.now() + date = format_date_time(mktime(now.timetuple())) + + signature_origin = "host: " + host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + path + " HTTP/1.1" + + # encrypt using hmac-sha256 + signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"' + + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + + v = { + "authorization": authorization, + "date": date, + "host": host + } + # generate url + url = api_base + '?' 
+ urlencode(v) + return url + + def run(self, messages: list, user_id: str, + model_kwargs: Optional[dict] = None, streaming: bool = False): + websocket.enableTrace(False) + ws = websocket.WebSocketApp( + self.ws_url, + on_message=self.on_message, + on_error=self.on_error, + on_close=self.on_close, + on_open=self.on_open + ) + ws.messages = messages + ws.user_id = user_id + ws.model_kwargs = model_kwargs + ws.streaming = streaming + ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE}) + + def on_error(self, ws, error): + self.queue.put({'error': error}) + ws.close() + + def on_close(self, ws, close_status_code, close_reason): + self.queue.put({'done': True}) + + def on_open(self, ws): + self.blocking_message = '' + data = json.dumps(self.gen_params( + messages=ws.messages, + user_id=ws.user_id, + model_kwargs=ws.model_kwargs + )) + ws.send(data) + + def on_message(self, ws, message): + data = json.loads(message) + code = data['header']['code'] + if code != 0: + self.queue.put({'error': f"Code: {code}, Error: {data['header']['message']}"}) + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + if ws.streaming: + self.queue.put({'data': content}) + else: + self.blocking_message += content + + if status == 2: + if not ws.streaming: + self.queue.put({'data': self.blocking_message}) + ws.close() + + def gen_params(self, messages: list, user_id: str, + model_kwargs: Optional[dict] = None) -> dict: + data = { + "header": { + "app_id": self.app_id, + "uid": user_id + }, + "parameter": { + "chat": { + "domain": "general" + } + }, + "payload": { + "message": { + "text": messages + } + } + } + + if model_kwargs: + data['parameter']['chat'].update(model_kwargs) + + return data + + def subscribe(self): + while True: + content = self.queue.get() + if 'error' in content: + raise SparkError(content['error']) + + if 'data' not in content: + break + yield content + + +class SparkError(Exception): + pass diff --git a/api/core/tool/dataset_index_tool.py b/api/core/tool/dataset_index_tool.py deleted file mode 100644 index c459ebaf13..0000000000 --- a/api/core/tool/dataset_index_tool.py +++ /dev/null @@ -1,102 +0,0 @@ -from flask import current_app -from langchain.embeddings import OpenAIEmbeddings -from langchain.tools import BaseTool - -from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.embedding.cached_embedding import CacheEmbedding -from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig -from core.index.vector_index.vector_index import VectorIndex -from core.llm.llm_builder import LLMBuilder -from models.dataset import Dataset, DocumentSegment - - -class DatasetTool(BaseTool): - """Tool for querying a Dataset.""" - - dataset: Dataset - k: int = 2 - - def _run(self, tool_input: str) -> str: - if self.dataset.indexing_technique == "economy": - # use keyword table query - kw_table_index = KeywordTableIndex( - dataset=self.dataset, - config=KeywordTableConfig( - max_keywords_per_chunk=5 - ) - ) - - documents = kw_table_index.search(tool_input, search_kwargs={'k': self.k}) - return str("\n".join([document.page_content for document in documents])) - else: - model_credentials = LLMBuilder.get_model_credentials( - tenant_id=self.dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'), - model_name='text-embedding-ada-002' - ) - - embeddings = CacheEmbedding(OpenAIEmbeddings( - 
**model_credentials - )) - - vector_index = VectorIndex( - dataset=self.dataset, - config=current_app.config, - embeddings=embeddings - ) - - documents = vector_index.search( - tool_input, - search_type='similarity', - search_kwargs={ - 'k': self.k - } - ) - - hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) - hit_callback.on_tool_end(documents) - document_context_list = [] - index_node_ids = [document.metadata['doc_id'] for document in documents] - segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == 'completed', - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids) - ).all() - - if segments: - for segment in segments: - if segment.answer: - document_context_list.append(segment.answer) - else: - document_context_list.append(segment.content) - - return str("\n".join(document_context_list)) - - async def _arun(self, tool_input: str) -> str: - model_credentials = LLMBuilder.get_model_credentials( - tenant_id=self.dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(self.dataset.tenant_id, 'text-embedding-ada-002'), - model_name='text-embedding-ada-002' - ) - - embeddings = CacheEmbedding(OpenAIEmbeddings( - **model_credentials - )) - - vector_index = VectorIndex( - dataset=self.dataset, - config=current_app.config, - embeddings=embeddings - ) - - documents = await vector_index.asearch( - tool_input, - search_type='similarity', - search_kwargs={ - 'k': 10 - } - ) - - hit_callback = DatasetIndexToolCallbackHandler(self.dataset.id) - hit_callback.on_tool_end(documents) - return str("\n".join([document.page_content for document in documents])) diff --git a/api/core/tool/dataset_retriever_tool.py b/api/core/tool/dataset_retriever_tool.py index 35f15bbceb..57ff10ae9b 100644 --- a/api/core/tool/dataset_retriever_tool.py +++ b/api/core/tool/dataset_retriever_tool.py @@ -2,7 +2,6 @@ import re from typing import Type from flask import current_app -from langchain.embeddings import OpenAIEmbeddings from langchain.tools import BaseTool from pydantic import Field, BaseModel @@ -10,7 +9,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import KeywordTableIndex, KeywordTableConfig from core.index.vector_index.vector_index import VectorIndex -from core.llm.llm_builder import LLMBuilder +from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from models.dataset import Dataset, DocumentSegment @@ -71,15 +70,11 @@ class DatasetRetrieverTool(BaseTool): documents = kw_table_index.search(query, search_kwargs={'k': self.k}) return str("\n".join([document.page_content for document in documents])) else: - model_credentials = LLMBuilder.get_model_credentials( - tenant_id=dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'), - model_name='text-embedding-ada-002' + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id ) - embeddings = CacheEmbedding(OpenAIEmbeddings( - **model_credentials - )) + embeddings = CacheEmbedding(embedding_model) vector_index = VectorIndex( dataset=dataset, diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 28a94d7a7d..02020e9192 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -1,7 +1,5 @@ from 
.create_installed_app_when_app_created import handle from .delete_installed_app_when_app_deleted import handle -from .create_provider_when_tenant_created import handle -from .create_provider_when_tenant_updated import handle from .clean_when_document_deleted import handle from .clean_when_dataset_deleted import handle from .update_app_dataset_join_when_app_model_config_updated import handle diff --git a/api/events/event_handlers/create_provider_when_tenant_created.py b/api/events/event_handlers/create_provider_when_tenant_created.py deleted file mode 100644 index 0d35670258..0000000000 --- a/api/events/event_handlers/create_provider_when_tenant_created.py +++ /dev/null @@ -1,24 +0,0 @@ -from flask import current_app - -from events.tenant_event import tenant_was_updated -from models.provider import ProviderName -from services.provider_service import ProviderService - - -@tenant_was_updated.connect -def handle(sender, **kwargs): - tenant = sender - if tenant.status == 'normal': - ProviderService.create_system_provider( - tenant, - ProviderName.OPENAI.value, - current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'], - True - ) - - ProviderService.create_system_provider( - tenant, - ProviderName.ANTHROPIC.value, - current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'], - True - ) diff --git a/api/events/event_handlers/create_provider_when_tenant_updated.py b/api/events/event_handlers/create_provider_when_tenant_updated.py deleted file mode 100644 index 366e13c599..0000000000 --- a/api/events/event_handlers/create_provider_when_tenant_updated.py +++ /dev/null @@ -1,24 +0,0 @@ -from flask import current_app - -from events.tenant_event import tenant_was_created -from models.provider import ProviderName -from services.provider_service import ProviderService - - -@tenant_was_created.connect -def handle(sender, **kwargs): - tenant = sender - if tenant.status == 'normal': - ProviderService.create_system_provider( - tenant, - ProviderName.OPENAI.value, - current_app.config['OPENAI_HOSTED_QUOTA_LIMIT'], - True - ) - - ProviderService.create_system_provider( - tenant, - ProviderName.ANTHROPIC.value, - current_app.config['ANTHROPIC_HOSTED_QUOTA_LIMIT'], - True - ) diff --git a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py index 4c1bbee53e..dc18bf44f4 100644 --- a/api/events/event_handlers/generate_conversation_name_when_first_message_created.py +++ b/api/events/event_handlers/generate_conversation_name_when_first_message_created.py @@ -23,7 +23,6 @@ def handle(sender, **kwargs): conversation.name = name except: conversation.name = 'New Chat' - logging.exception('generate_conversation_name failed') db.session.add(conversation) db.session.commit() diff --git a/api/extensions/ext_stripe.py b/api/extensions/ext_stripe.py new file mode 100644 index 0000000000..3a192c081a --- /dev/null +++ b/api/extensions/ext_stripe.py @@ -0,0 +1,6 @@ +import stripe + + +def init_app(app): + if app.config.get('STRIPE_API_KEY'): + stripe.api_key = app.config.get('STRIPE_API_KEY') diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 8741989a9a..a04282a5f8 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,16 +1,14 @@ # -*- coding:utf-8 -*- import hashlib -from Crypto.Cipher import PKCS1_OAEP +from Crypto.Cipher import PKCS1_OAEP, AES from Crypto.PublicKey import RSA +from Crypto.Random import get_random_bytes from extensions.ext_redis import redis_client from extensions.ext_storage import storage -# TODO: 
PKCS1_OAEP is no longer recommended for new systems and protocols. It is recommended to migrate to PKCS1_PSS. - - def generate_key_pair(tenant_id): private_key = RSA.generate(2048) public_key = private_key.publickey() @@ -25,14 +23,26 @@ def generate_key_pair(tenant_id): return pem_public.decode() +prefix_hybrid = b"HYBRID:" + + def encrypt(text, public_key): if isinstance(public_key, str): public_key = public_key.encode() + aes_key = get_random_bytes(16) + cipher_aes = AES.new(aes_key, AES.MODE_EAX) + + ciphertext, tag = cipher_aes.encrypt_and_digest(text.encode()) + rsa_key = RSA.import_key(public_key) - cipher = PKCS1_OAEP.new(rsa_key) - encrypted_text = cipher.encrypt(text.encode()) - return encrypted_text + cipher_rsa = PKCS1_OAEP.new(rsa_key) + + enc_aes_key = cipher_rsa.encrypt(aes_key) + + encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext + + return prefix_hybrid + encrypted_data def decrypt(encrypted_text, tenant_id): @@ -49,8 +59,23 @@ def decrypt(encrypted_text, tenant_id): redis_client.setex(cache_key, 120, private_key) rsa_key = RSA.import_key(private_key) - cipher = PKCS1_OAEP.new(rsa_key) - decrypted_text = cipher.decrypt(encrypted_text) + cipher_rsa = PKCS1_OAEP.new(rsa_key) + + if encrypted_text.startswith(prefix_hybrid): + encrypted_text = encrypted_text[len(prefix_hybrid):] + + enc_aes_key = encrypted_text[:rsa_key.size_in_bytes()] + nonce = encrypted_text[rsa_key.size_in_bytes():rsa_key.size_in_bytes() + 16] + tag = encrypted_text[rsa_key.size_in_bytes() + 16:rsa_key.size_in_bytes() + 32] + ciphertext = encrypted_text[rsa_key.size_in_bytes() + 32:] + + aes_key = cipher_rsa.decrypt(enc_aes_key) + + cipher_aes = AES.new(aes_key, AES.MODE_EAX, nonce=nonce) + decrypted_text = cipher_aes.decrypt_and_verify(ciphertext, tag) + else: + decrypted_text = cipher_rsa.decrypt(encrypted_text) + return decrypted_text.decode() diff --git a/api/migrations/versions/16fa53d9faec_add_provider_model_support.py b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py new file mode 100644 index 0000000000..92b1fba9c7 --- /dev/null +++ b/api/migrations/versions/16fa53d9faec_add_provider_model_support.py @@ -0,0 +1,79 @@ +"""add provider model support + +Revision ID: 16fa53d9faec +Revises: 8d2d099ceb74 +Create Date: 2023-08-06 16:57:51.248337 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '16fa53d9faec' +down_revision = '8d2d099ceb74' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
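+# A standalone round-trip of the hybrid scheme added in libs/rsa.py above
+# (throwaway key pair, pycryptodome only; a sketch mirroring encrypt()/decrypt()):
+#
+#   from Crypto.Cipher import AES, PKCS1_OAEP
+#   from Crypto.PublicKey import RSA
+#   from Crypto.Random import get_random_bytes
+#
+#   key = RSA.generate(2048)
+#   aes_key = get_random_bytes(16)
+#   cipher_aes = AES.new(aes_key, AES.MODE_EAX)
+#   ciphertext, tag = cipher_aes.encrypt_and_digest(b"secret")
+#   blob = (b"HYBRID:" + PKCS1_OAEP.new(key.publickey()).encrypt(aes_key)
+#           + cipher_aes.nonce + tag + ciphertext)
+#
+#   body = blob[len(b"HYBRID:"):]
+#   n = key.size_in_bytes()  # 256 for a 2048-bit key
+#   aes_key2 = PKCS1_OAEP.new(key).decrypt(body[:n])
+#   cipher = AES.new(aes_key2, AES.MODE_EAX, nonce=body[n:n + 16])
+#   assert cipher.decrypt_and_verify(body[n + 32:], body[n + 16:n + 32]) == b"secret"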
### + op.create_table('provider_models', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('model_name', sa.String(length=40), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('encrypted_config', sa.Text(), nullable=True), + sa.Column('is_valid', sa.Boolean(), server_default=sa.text('false'), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_model_pkey'), + sa.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + ) + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.create_index('provider_model_tenant_id_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + op.create_table('tenant_default_models', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('model_name', sa.String(length=40), nullable=False), + sa.Column('model_type', sa.String(length=40), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_default_model_pkey') + ) + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.create_index('tenant_default_model_tenant_id_provider_type_idx', ['tenant_id', 'provider_name', 'model_type'], unique=False) + + op.create_table('tenant_preferred_model_providers', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('preferred_provider_type', sa.String(length=40), nullable=False), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey') + ) + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.create_index('tenant_preferred_model_provider_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('tenant_preferred_model_providers', schema=None) as batch_op: + batch_op.drop_index('tenant_preferred_model_provider_tenant_provider_idx') + + op.drop_table('tenant_preferred_model_providers') + with op.batch_alter_table('tenant_default_models', schema=None) as batch_op: + batch_op.drop_index('tenant_default_model_tenant_id_provider_type_idx') + + op.drop_table('tenant_default_models') + with op.batch_alter_table('provider_models', schema=None) as batch_op: + batch_op.drop_index('provider_model_tenant_id_provider_idx') + + op.drop_table('provider_models') + # ### end Alembic commands ### diff --git a/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py new file mode 100644 index 0000000000..182db6ccc3 --- /dev/null +++ b/api/migrations/versions/5022897aaceb_add_model_name_in_embedding.py @@ -0,0 +1,36 @@ +"""add model name in embedding + +Revision ID: 5022897aaceb +Revises: bf0aec5ba2cf +Create Date: 2023-08-11 14:38:15.499460 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5022897aaceb' +down_revision = 'bf0aec5ba2cf' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.add_column(sa.Column('model_name', sa.String(length=40), server_default=sa.text("'text-embedding-ada-002'::character varying"), nullable=False)) + batch_op.drop_constraint('embedding_hash_idx', type_='unique') + batch_op.create_unique_constraint('embedding_hash_idx', ['model_name', 'hash']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('embeddings', schema=None) as batch_op: + batch_op.drop_constraint('embedding_hash_idx', type_='unique') + batch_op.create_unique_constraint('embedding_hash_idx', ['hash']) + batch_op.drop_column('model_name') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py new file mode 100644 index 0000000000..aa9f74fe38 --- /dev/null +++ b/api/migrations/versions/bf0aec5ba2cf_add_provider_order.py @@ -0,0 +1,52 @@ +"""add provider order + +Revision ID: bf0aec5ba2cf +Revises: e35ed59becda +Create Date: 2023-08-10 00:03:44.273430 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'bf0aec5ba2cf' +down_revision = 'e35ed59becda' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
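+# Side note on 5022897aaceb above: widening the embeddings unique constraint to
+# (model_name, hash) scopes the vector cache per embedding model, so two models
+# embedding the same text no longer collide on `hash` alone.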
### + op.create_table('provider_orders', + sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', postgresql.UUID(), nullable=False), + sa.Column('provider_name', sa.String(length=40), nullable=False), + sa.Column('account_id', postgresql.UUID(), nullable=False), + sa.Column('payment_product_id', sa.String(length=191), nullable=False), + sa.Column('payment_id', sa.String(length=191), nullable=True), + sa.Column('transaction_id', sa.String(length=191), nullable=True), + sa.Column('quantity', sa.Integer(), server_default=sa.text('1'), nullable=False), + sa.Column('currency', sa.String(length=40), nullable=True), + sa.Column('total_amount', sa.Integer(), nullable=True), + sa.Column('payment_status', sa.String(length=40), server_default=sa.text("'wait_pay'::character varying"), nullable=False), + sa.Column('paid_at', sa.DateTime(), nullable=True), + sa.Column('pay_failed_at', sa.DateTime(), nullable=True), + sa.Column('refunded_at', sa.DateTime(), nullable=True), + sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False), + sa.PrimaryKeyConstraint('id', name='provider_order_pkey') + ) + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.create_index('provider_order_tenant_provider_idx', ['tenant_id', 'provider_name'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('provider_orders', schema=None) as batch_op: + batch_op.drop_index('provider_order_tenant_provider_idx') + + op.drop_table('provider_orders') + # ### end Alembic commands ### diff --git a/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py b/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py new file mode 100644 index 0000000000..e9056d57f9 --- /dev/null +++ b/api/migrations/versions/e35ed59becda_modify_quota_limit_field_type.py @@ -0,0 +1,46 @@ +"""modify quota limit field type + +Revision ID: e35ed59becda +Revises: 16fa53d9faec +Create Date: 2023-08-09 22:20:31.577953 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e35ed59becda' +down_revision = '16fa53d9faec' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('quota_limit', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=True) + batch_op.alter_column('quota_used', + existing_type=sa.INTEGER(), + type_=sa.BigInteger(), + existing_nullable=True) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('providers', schema=None) as batch_op: + batch_op.alter_column('quota_used', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=True) + batch_op.alter_column('quota_limit', + existing_type=sa.BigInteger(), + type_=sa.INTEGER(), + existing_nullable=True) + + # ### end Alembic commands ### diff --git a/api/models/dataset.py b/api/models/dataset.py index b63b898df4..ecf087ef65 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -9,6 +9,7 @@ from extensions.ext_database import db from models.account import Account from models.model import App, UploadFile + class Dataset(db.Model): __tablename__ = 'datasets' __table_args__ = ( @@ -268,7 +269,7 @@ class Document(db.Model): @property def average_segment_length(self): if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0: - return self.word_count//self.segment_count + return self.word_count // self.segment_count return 0 @property @@ -346,16 +347,6 @@ class DocumentSegment(db.Model): def document(self): return db.session.query(Document).filter(Document.id == self.document_id).first() - @property - def embedding(self): - embedding = db.session.query(Embedding).filter(Embedding.hash == self.index_node_hash).first() \ - if self.index_node_hash else None - - if embedding: - return embedding.embedding - - return None - @property def previous_segment(self): return db.session.query(DocumentSegment).filter( @@ -436,10 +427,12 @@ class Embedding(db.Model): __tablename__ = 'embeddings' __table_args__ = ( db.PrimaryKeyConstraint('id', name='embedding_pkey'), - db.UniqueConstraint('hash', name='embedding_hash_idx') + db.UniqueConstraint('model_name', 'hash', name='embedding_hash_idx') ) id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()')) + model_name = db.Column(db.String(40), nullable=False, + server_default=db.text("'text-embedding-ada-002'::character varying")) hash = db.Column(db.String(64), nullable=False) embedding = db.Column(db.LargeBinary, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -448,4 +441,4 @@ class Embedding(db.Model): self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL) def get_embedding(self) -> list[float]: - return pickle.loads(self.embedding) \ No newline at end of file + return pickle.loads(self.embedding) diff --git a/api/models/provider.py b/api/models/provider.py index e4ecfa1241..63e9785a96 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -9,25 +9,30 @@ class ProviderType(Enum): CUSTOM = 'custom' SYSTEM = 'system' - -class ProviderName(Enum): - OPENAI = 'openai' - AZURE_OPENAI = 'azure_openai' - ANTHROPIC = 'anthropic' - COHERE = 'cohere' - HUGGINGFACEHUB = 'huggingfacehub' - @staticmethod def value_of(value): - for member in ProviderName: + for member in ProviderType: if member.value == value: return member raise ValueError(f"No matching enum found for value '{value}'") class ProviderQuotaType(Enum): - MONTHLY = 'monthly' + PAID = 'paid' + """hosted paid quota""" + + FREE = 'free' + """third-party free quota""" + TRIAL = 'trial' + """hosted trial quota""" + + @staticmethod + def value_of(value): + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") class Provider(db.Model): @@ -50,8 +55,8 @@ class Provider(db.Model): last_used = db.Column(db.DateTime, nullable=True) quota_type = 
db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.Integer, nullable=True) - quota_used = db.Column(db.Integer, default=0) + quota_limit = db.Column(db.BigInteger, nullable=True) + quota_used = db.Column(db.BigInteger, default=0) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) @@ -75,3 +80,96 @@ class Provider(db.Model): return self.is_valid else: return self.is_valid and self.token_is_set + + +class ProviderModel(db.Model): + """ + Provider model representing the API provider_models and their configurations. + """ + __tablename__ = 'provider_models' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='provider_model_pkey'), + db.Index('provider_model_tenant_id_provider_idx', 'tenant_id', 'provider_name'), + db.UniqueConstraint('tenant_id', 'provider_name', 'model_name', 'model_type', name='unique_provider_model_name') + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + provider_name = db.Column(db.String(40), nullable=False) + model_name = db.Column(db.String(40), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + encrypted_config = db.Column(db.Text, nullable=True) + is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text('false')) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class TenantDefaultModel(db.Model): + __tablename__ = 'tenant_default_models' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tenant_default_model_pkey'), + db.Index('tenant_default_model_tenant_id_provider_type_idx', 'tenant_id', 'provider_name', 'model_type'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + provider_name = db.Column(db.String(40), nullable=False) + model_name = db.Column(db.String(40), nullable=False) + model_type = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class TenantPreferredModelProvider(db.Model): + __tablename__ = 'tenant_preferred_model_providers' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='tenant_preferred_model_provider_pkey'), + db.Index('tenant_preferred_model_provider_tenant_provider_idx', 'tenant_id', 'provider_name'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + provider_name = db.Column(db.String(40), nullable=False) + preferred_provider_type = db.Column(db.String(40), nullable=False) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + + +class ProviderOrderPaymentStatus(Enum): + WAIT_PAY = 'wait_pay' + PAID = 'paid' + PAY_FAILED = 'pay_failed' + REFUNDED = 'refunded' + + @staticmethod + def value_of(value): + for member in ProviderOrderPaymentStatus: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") + + 
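+# Note: `ProviderOrderPaymentStatus('paid')` would also resolve by value; the
+# explicit value_of helper mirrors ProviderType/ProviderQuotaType and raises a
+# clearer error message for unknown status strings.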
+ +class ProviderOrder(db.Model): + __tablename__ = 'provider_orders' + __table_args__ = ( + db.PrimaryKeyConstraint('id', name='provider_order_pkey'), + db.Index('provider_order_tenant_provider_idx', 'tenant_id', 'provider_name'), + ) + + id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) + tenant_id = db.Column(UUID, nullable=False) + provider_name = db.Column(db.String(40), nullable=False) + account_id = db.Column(UUID, nullable=False) + payment_product_id = db.Column(db.String(191), nullable=False) + payment_id = db.Column(db.String(191)) + transaction_id = db.Column(db.String(191)) + quantity = db.Column(db.Integer, nullable=False, server_default=db.text('1')) + currency = db.Column(db.String(40)) + total_amount = db.Column(db.Integer) + payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) + paid_at = db.Column(db.DateTime) + pay_failed_at = db.Column(db.DateTime) + refunded_at = db.Column(db.DateTime) + created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) diff --git a/api/requirements.txt b/api/requirements.txt index ccbb0e18cf..ac87a58ea7 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -10,12 +10,13 @@ flask-session2==1.3.1 flask-cors==3.0.10 gunicorn~=21.2.0 gevent~=22.10.2 -langchain==0.0.239 +langchain==0.0.250 openai~=0.27.8 psycopg2-binary~=2.9.6 pycryptodome==3.17 python-dotenv==1.0.0 pytest~=7.3.1 +pytest-mock~=3.11.1 tiktoken==0.3.3 Authlib==1.2.0 boto3~=1.26.123 @@ -40,4 +41,10 @@ newspaper3k==0.2.8 google-api-python-client==2.90.0 wikipedia==1.4.0 readabilipy==0.2.0 -google-search-results==2.4.2 \ No newline at end of file +google-search-results==2.4.2 +replicate~=0.9.0 +websocket-client~=1.6.1 +dashscope~=1.5.0 +huggingface_hub~=0.16.4 +transformers~=4.31.0 +stripe~=5.5.0 \ No newline at end of file diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index a5af67e301..abcb7e6235 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -2,42 +2,11 @@ import re import uuid from core.agent.agent_executor import PlanningStrategy -from core.constant import llm_constant +from core.model_providers.model_provider_factory import ModelProviderFactory +from core.model_providers.models.entity.model_params import ModelType from models.account import Account from services.dataset_service import DatasetService -from core.llm.llm_builder import LLMBuilder -MODEL_PROVIDERS = [ - 'openai', - 'anthropic', -] - -MODELS_BY_APP_MODE = { - 'chat': [ - 'claude-instant-1', - 'claude-2', - 'gpt-4', - 'gpt-4-32k', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k', - ], - 'completion': [ - 'claude-instant-1', - 'claude-2', - 'gpt-4', - 'gpt-4-32k', - 'gpt-3.5-turbo', - 'gpt-3.5-turbo-16k', - 'text-davinci-003', - ] -} - -SUPPORT_AGENT_MODELS = [ - "gpt-4", - "gpt-4-32k", - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", -] SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -65,40 +34,40 @@ class AppModelConfigService: # max_tokens if 'max_tokens' not in cp: cp["max_tokens"] = 512 - - if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \ - llm_constant.max_context_token_length[model_name]: - raise ValueError( - "max_tokens must be an integer greater than 0 " - "and not exceeding the maximum value of the 
corresponding model") - + # + # if not isinstance(cp["max_tokens"], int) or cp["max_tokens"] <= 0 or cp["max_tokens"] > \ + # llm_constant.max_context_token_length[model_name]: + # raise ValueError( + # "max_tokens must be an integer greater than 0 " + # "and not exceeding the maximum value of the corresponding model") + # # temperature if 'temperature' not in cp: cp["temperature"] = 1 - - if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2: - raise ValueError("temperature must be a float between 0 and 2") - + # + # if not isinstance(cp["temperature"], (float, int)) or cp["temperature"] < 0 or cp["temperature"] > 2: + # raise ValueError("temperature must be a float between 0 and 2") + # # top_p if 'top_p' not in cp: cp["top_p"] = 1 - if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2: - raise ValueError("top_p must be a float between 0 and 2") - + # if not isinstance(cp["top_p"], (float, int)) or cp["top_p"] < 0 or cp["top_p"] > 2: + # raise ValueError("top_p must be a float between 0 and 2") + # # presence_penalty if 'presence_penalty' not in cp: cp["presence_penalty"] = 0 - if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2: - raise ValueError("presence_penalty must be a float between -2 and 2") - + # if not isinstance(cp["presence_penalty"], (float, int)) or cp["presence_penalty"] < -2 or cp["presence_penalty"] > 2: + # raise ValueError("presence_penalty must be a float between -2 and 2") + # # presence_penalty if 'frequency_penalty' not in cp: cp["frequency_penalty"] = 0 - if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2: - raise ValueError("frequency_penalty must be a float between -2 and 2") + # if not isinstance(cp["frequency_penalty"], (float, int)) or cp["frequency_penalty"] < -2 or cp["frequency_penalty"] > 2: + # raise ValueError("frequency_penalty must be a float between -2 and 2") # Filter out extra parameters filtered_cp = { @@ -112,7 +81,7 @@ class AppModelConfigService: return filtered_cp @staticmethod - def validate_configuration(account: Account, config: dict, mode: str) -> dict: + def validate_configuration(tenant_id: str, account: Account, config: dict) -> dict: # opening_statement if 'opening_statement' not in config or not config["opening_statement"]: config["opening_statement"] = "" @@ -211,14 +180,21 @@ class AppModelConfigService: raise ValueError("model must be of object type") # model.provider - if 'provider' not in config["model"] or config["model"]["provider"] not in MODEL_PROVIDERS: - raise ValueError(f"model.provider is required and must be in {str(MODEL_PROVIDERS)}") + model_provider_names = ModelProviderFactory.get_provider_names() + if 'provider' not in config["model"] or config["model"]["provider"] not in model_provider_names: + raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") # model.name if 'name' not in config["model"]: raise ValueError("model.name is required") - if config["model"]["name"] not in MODELS_BY_APP_MODE[mode]: + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, config["model"]["provider"]) + if not model_provider: + raise ValueError("model.name must be in the specified model list") + + model_list = model_provider.get_supported_model_list(ModelType.TEXT_GENERATION) + model_ids = [m['id'] for m in model_list] + if config["model"]["name"] not in model_ids: 
raise ValueError("model.name must be in the specified model list") # model.completion_params diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 667fb4cb67..db1f0fe218 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -1,15 +1,13 @@ import io from werkzeug.datastructures import FileStorage -from core.llm.llm_builder import LLMBuilder -from core.llm.provider.llm_provider_service import LLMProviderService +from core.model_providers.model_factory import ModelFactory from services.errors.audio import NoAudioUploadedServiceError, AudioTooLargeServiceError, UnsupportedAudioTypeServiceError, ProviderNotSupportSpeechToTextServiceError -from core.llm.whisper import Whisper -from models.provider import ProviderName FILE_SIZE = 15 FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024 ALLOWED_EXTENSIONS = ['mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'wav', 'webm'] + class AudioService: @classmethod def transcript(cls, tenant_id: str, file: FileStorage): @@ -26,14 +24,12 @@ class AudioService: if file_size > FILE_SIZE_LIMIT: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - - provider_name = LLMBuilder.get_default_provider(tenant_id, 'whisper-1') - if provider_name != ProviderName.OPENAI.value: - raise ProviderNotSupportSpeechToTextServiceError() - provider_service = LLMProviderService(tenant_id, provider_name) + model = ModelFactory.get_speech2text_model( + tenant_id=tenant_id + ) buffer = io.BytesIO(file_content) buffer.name = 'temp.mp3' - return Whisper(provider_service.provider).transcribe(buffer) + return model.run(buffer) diff --git a/api/services/completion_service.py b/api/services/completion_service.py index c081d8ec08..8899cdc11b 100644 --- a/api/services/completion_service.py +++ b/api/services/completion_service.py @@ -11,7 +11,7 @@ from sqlalchemy import and_ from core.completion import Completion from core.conversation_message_task import PubHandler, ConversationTaskStoppedException -from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \ +from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \ LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -127,9 +127,9 @@ class CompletionService: # validate config model_config = AppModelConfigService.validate_configuration( + tenant_id=app_model.tenant_id, account=user, - config=args['model_config'], - mode=app_model.mode + config=args['model_config'] ) app_model_config = AppModelConfig( diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 3f3f3652e6..5edd4b3da8 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -9,8 +9,7 @@ from typing import Optional, List from flask import current_app from sqlalchemy import func -from core.llm.token_calculator import TokenCalculator -from events.event_handlers.document_index_event import document_index_created +from core.model_providers.model_factory import ModelFactory from extensions.ext_redis import redis_client from flask_login import current_user @@ -875,8 +874,13 @@ class SegmentService: content = args['content'] doc_id = str(uuid.uuid4()) segment_hash = helper.generate_text_hash(content) + + embedding_model = ModelFactory.get_embedding_model( + tenant_id=document.tenant_id + ) + # calc 
embedding use tokens - tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content) + tokens = embedding_model.get_num_tokens(content) max_position = db.session.query(func.max(DocumentSegment.position)).filter( DocumentSegment.document_id == document.id ).scalar() @@ -921,8 +925,13 @@ class SegmentService: update_segment_keyword_index_task.delay(segment.id) else: segment_hash = helper.generate_text_hash(content) + + embedding_model = ModelFactory.get_embedding_model( + tenant_id=document.tenant_id + ) + # calc embedding use tokens - tokens = TokenCalculator.get_num_tokens('text-embedding-ada-002', content) + tokens = embedding_model.get_num_tokens(content) segment.content = content segment.index_node_hash = segment_hash segment.word_count = len(content) diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 17a4a4f4c6..3c1247ba56 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,14 +4,13 @@ from typing import List import numpy as np from flask import current_app -from langchain.embeddings import OpenAIEmbeddings from langchain.embeddings.base import Embeddings from langchain.schema import Document from sklearn.manifold import TSNE from core.embedding.cached_embedding import CacheEmbedding from core.index.vector_index.vector_index import VectorIndex -from core.llm.llm_builder import LLMBuilder +from core.model_providers.model_factory import ModelFactory from extensions.ext_database import db from models.account import Account from models.dataset import Dataset, DocumentSegment, DatasetQuery @@ -29,15 +28,11 @@ class HitTestingService: "records": [] } - model_credentials = LLMBuilder.get_model_credentials( - tenant_id=dataset.tenant_id, - model_provider=LLMBuilder.get_default_provider(dataset.tenant_id, 'text-embedding-ada-002'), - model_name='text-embedding-ada-002' + embedding_model = ModelFactory.get_embedding_model( + tenant_id=dataset.tenant_id ) - embeddings = CacheEmbedding(OpenAIEmbeddings( - **model_credentials - )) + embeddings = CacheEmbedding(embedding_model) vector_index = VectorIndex( dataset=dataset, diff --git a/api/services/provider_checkout_service.py b/api/services/provider_checkout_service.py new file mode 100644 index 0000000000..80391dfac1 --- /dev/null +++ b/api/services/provider_checkout_service.py @@ -0,0 +1,158 @@ +import datetime +import logging + +import stripe +from flask import current_app + +from core.model_providers.model_provider_factory import ModelProviderFactory +from extensions.ext_database import db +from models.account import Account +from models.provider import ProviderOrder, ProviderOrderPaymentStatus, ProviderType, Provider, ProviderQuotaType + + +class ProviderCheckout: + def __init__(self, stripe_checkout_session): + self.stripe_checkout_session = stripe_checkout_session + + def get_checkout_url(self): + return self.stripe_checkout_session.url + + +class ProviderCheckoutService: + def create_checkout(self, tenant_id: str, provider_name: str, account: Account) -> ProviderCheckout: + # check provider name is valid + model_provider_rules = ModelProviderFactory.get_provider_rules() + if provider_name not in model_provider_rules: + raise ValueError(f'provider name {provider_name} is invalid') + + model_provider_rule = model_provider_rules[provider_name] + + # check provider name can be paid + self._check_provider_payable(provider_name, model_provider_rule) + + # get stripe checkout product id + paid_provider = self._get_paid_provider(tenant_id, 
provider_name)
+        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name)
+        model_provider = model_provider_class(provider=paid_provider)
+        payment_info = model_provider.get_payment_info()
+        if not payment_info:
+            raise ValueError(f'provider {provider_name} does not support payment')
+
+        payment_product_id = payment_info['product_id']
+
+        # create provider order
+        provider_order = ProviderOrder(
+            tenant_id=tenant_id,
+            provider_name=provider_name,
+            account_id=account.id,
+            payment_product_id=payment_product_id,
+            quantity=1,
+            payment_status=ProviderOrderPaymentStatus.WAIT_PAY.value
+        )
+
+        db.session.add(provider_order)
+        db.session.flush()
+
+        try:
+            # create stripe checkout session
+            checkout_session = stripe.checkout.Session.create(
+                line_items=[
+                    {
+                        'price': f'{payment_product_id}',
+                        'quantity': 1,
+                    },
+                ],
+                mode='payment',
+                success_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=succeeded',
+                cancel_url=current_app.config.get("CONSOLE_WEB_URL") + '?provider_payment=cancelled',
+                automatic_tax={'enabled': True},
+            )
+        except Exception as e:
+            logging.exception(e)
+            raise ValueError(f'failed to create checkout session for provider {provider_name}, please try again later')
+
+        provider_order.payment_id = checkout_session.id
+        db.session.commit()
+
+        return ProviderCheckout(checkout_session)
+
+    def fulfill_provider_order(self, event):
+        provider_order = db.session.query(ProviderOrder) \
+            .filter(ProviderOrder.payment_id == event['data']['object']['id']) \
+            .first()
+
+        if not provider_order:
+            raise ValueError(f'provider order not found, payment id: {event["data"]["object"]["id"]}')
+
+        if provider_order.payment_status != ProviderOrderPaymentStatus.WAIT_PAY.value:
+            raise ValueError(f'provider order payment status is not wait_pay, payment id: {event["data"]["object"]["id"]}')
+
+        provider_order.transaction_id = event['data']['object']['payment_intent']
+        provider_order.currency = event['data']['object']['currency']
+        provider_order.total_amount = event['data']['object']['amount_subtotal']
+        provider_order.payment_status = ProviderOrderPaymentStatus.PAID.value
+        provider_order.paid_at = datetime.datetime.utcnow()
+        provider_order.updated_at = provider_order.paid_at
+
+        # update provider quota
+        provider = db.session.query(Provider).filter(
+            Provider.tenant_id == provider_order.tenant_id,
+            Provider.provider_name == provider_order.provider_name,
+            Provider.provider_type == ProviderType.SYSTEM.value,
+            Provider.quota_type == ProviderQuotaType.PAID.value
+        ).first()
+
+        if not provider:
+            raise ValueError(f'provider not found, tenant id: {provider_order.tenant_id}, '
+                             f'provider name: {provider_order.provider_name}')
+
+        model_provider_class = ModelProviderFactory.get_model_provider_class(provider_order.provider_name)
+        model_provider = model_provider_class(provider=provider)
+        payment_info = model_provider.get_payment_info()
+
+        if not payment_info:
+            increase_quota = 0
+        else:
+            increase_quota = int(payment_info['increase_quota'])
+
+        if increase_quota > 0:
+            provider.quota_limit += increase_quota
+            provider.is_valid = True
+
+        db.session.commit()
+
+    def _check_provider_payable(self, provider_name: str, model_provider_rule: dict):
+        if ProviderType.SYSTEM.value not in model_provider_rule['support_provider_types']:
+            raise ValueError(f'provider {provider_name} does not support payment')
+
+        if 'system_config' not in model_provider_rule:
+            raise ValueError(f'provider {provider_name} does not support payment')
+
+        if 'supported_quota_types' not in model_provider_rule['system_config']:
+            raise ValueError(f'provider {provider_name} does not support payment')
+
+        if 'paid' not in model_provider_rule['system_config']['supported_quota_types']:
+            raise ValueError(f'provider {provider_name} does not support payment')
+
+    def _get_paid_provider(self, tenant_id: str, provider_name: str):
+        paid_provider = db.session.query(Provider) \
+            .filter(
+                Provider.tenant_id == tenant_id,
+                Provider.provider_name == provider_name,
+                Provider.provider_type == ProviderType.SYSTEM.value,
+                Provider.quota_type == ProviderQuotaType.PAID.value,
+            ).first()
+
+        if not paid_provider:
+            paid_provider = Provider(
+                tenant_id=tenant_id,
+                provider_name=provider_name,
+                provider_type=ProviderType.SYSTEM.value,
+                quota_type=ProviderQuotaType.PAID.value,
+                quota_limit=0,
+                quota_used=0,
+            )
+            db.session.add(paid_provider)
+            db.session.commit()
+
+        return paid_provider
diff --git a/api/services/provider_service.py b/api/services/provider_service.py
index fffd3fbd5b..de8f53d8fc 100644
--- a/api/services/provider_service.py
+++ b/api/services/provider_service.py
@@ -1,88 +1,503 @@
-from typing import Union
+import datetime
+import json
+from collections import defaultdict
+from typing import Optional
 
-from flask import current_app
-
-from core.llm.provider.llm_provider_service import LLMProviderService
-from models.account import Tenant
-from models.provider import *
+from core.model_providers.model_factory import ModelFactory
+from extensions.ext_database import db
+from core.model_providers.model_provider_factory import ModelProviderFactory
+from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules
+from models.provider import Provider, ProviderModel, TenantPreferredModelProvider, ProviderType, ProviderQuotaType, \
+    TenantDefaultModel
 
 
 class ProviderService:
-    @staticmethod
-    def init_supported_provider(tenant):
-        """Initialize the model provider, check whether the supported provider has a record"""
+    def get_provider_list(self, tenant_id: str):
+        """
+        get provider list of tenant.
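+        aggregates provider records, custom provider models and the
+        preferred provider type of each supported model provider.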
- need_init_provider_names = [ProviderName.OPENAI.value, ProviderName.AZURE_OPENAI.value, ProviderName.ANTHROPIC.value] + :param tenant_id: + :return: + """ + # get rules for all providers + model_provider_rules = ModelProviderFactory.get_provider_rules() + model_provider_names = [model_provider_name for model_provider_name, _ in model_provider_rules.items()] + configurable_model_provider_names = [ + model_provider_name + for model_provider_name, model_provider_rules in model_provider_rules.items() + if 'custom' in model_provider_rules['support_provider_types'] + and model_provider_rules['model_flexibility'] == 'configurable' + ] - providers = db.session.query(Provider).filter( - Provider.tenant_id == tenant.id, - Provider.provider_type == ProviderType.CUSTOM.value, - Provider.provider_name.in_(need_init_provider_names) + # get all providers for the tenant + providers = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name.in_(model_provider_names), + Provider.is_valid == True + ).order_by(Provider.created_at.desc()).all() + + provider_name_to_provider_dict = defaultdict(list) + for provider in providers: + provider_name_to_provider_dict[provider.provider_name].append(provider) + + # get all configurable provider models for the tenant + provider_models = db.session.query(ProviderModel) \ + .filter( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name.in_(configurable_model_provider_names), + ProviderModel.is_valid == True + ).order_by(ProviderModel.created_at.desc()).all() + + provider_name_to_provider_model_dict = defaultdict(list) + for provider_model in provider_models: + provider_name_to_provider_model_dict[provider_model.provider_name].append(provider_model) + + # get all preferred provider type for the tenant + preferred_provider_types = db.session.query(TenantPreferredModelProvider) \ + .filter( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name.in_(model_provider_names) ).all() - exists_provider_names = [] - for provider in providers: - exists_provider_names.append(provider.provider_name) + provider_name_to_preferred_provider_type_dict = {preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types} - not_exists_provider_names = list(set(need_init_provider_names) - set(exists_provider_names)) + providers_list = {} - if not_exists_provider_names: - # Initialize the model provider, check whether the supported provider has a record - for provider_name in not_exists_provider_names: - provider = Provider( - tenant_id=tenant.id, - provider_name=provider_name, - provider_type=ProviderType.CUSTOM.value, - is_valid=False - ) - db.session.add(provider) + for model_provider_name, model_provider_rule in model_provider_rules.items(): + # get preferred provider type + preferred_model_provider = provider_name_to_preferred_provider_type_dict.get(model_provider_name) + preferred_provider_type = ModelProviderFactory.get_preferred_type_by_preferred_model_provider( + tenant_id, + model_provider_name, + preferred_model_provider + ) - db.session.commit() + provider_config_dict = { + "preferred_provider_type": preferred_provider_type, + "model_flexibility": model_provider_rule['model_flexibility'], + } - @staticmethod - def get_obfuscated_api_key(tenant, provider_name: ProviderName, only_custom: bool = False): - llm_provider_service = LLMProviderService(tenant.id, provider_name.value) - return 
llm_provider_service.get_provider_configs(obfuscated=True, only_custom=only_custom) + provider_parameter_dict = {} + if ProviderType.SYSTEM.value in model_provider_rule['support_provider_types']: + for quota_type_enum in ProviderQuotaType: + quota_type = quota_type_enum.value + if quota_type in model_provider_rule['system_config']['supported_quota_types']: + key = ProviderType.SYSTEM.value + ':' + quota_type + provider_parameter_dict[key] = { + "provider_name": model_provider_name, + "provider_type": ProviderType.SYSTEM.value, + "config": None, + "is_valid": False, # need update + "quota_type": quota_type, + "quota_unit": model_provider_rule['system_config']['quota_unit'], # need update + "quota_limit": 0 if quota_type != ProviderQuotaType.TRIAL.value else + model_provider_rule['system_config']['quota_limit'], # need update + "quota_used": 0, # need update + "last_used": None # need update + } - @staticmethod - def get_token_type(tenant, provider_name: ProviderName): - llm_provider_service = LLMProviderService(tenant.id, provider_name.value) - return llm_provider_service.get_token_type() + if ProviderType.CUSTOM.value in model_provider_rule['support_provider_types']: + provider_parameter_dict[ProviderType.CUSTOM.value] = { + "provider_name": model_provider_name, + "provider_type": ProviderType.CUSTOM.value, + "config": None, # need update + "models": [], # need update + "is_valid": False, + "last_used": None # need update + } - @staticmethod - def validate_provider_configs(tenant, provider_name: ProviderName, configs: Union[dict | str]): - if current_app.config['DISABLE_PROVIDER_CONFIG_VALIDATION']: - return - llm_provider_service = LLMProviderService(tenant.id, provider_name.value) - return llm_provider_service.config_validate(configs) + model_provider_class = ModelProviderFactory.get_model_provider_class(model_provider_name) - @staticmethod - def get_encrypted_token(tenant, provider_name: ProviderName, configs: Union[dict | str]): - llm_provider_service = LLMProviderService(tenant.id, provider_name.value) - return llm_provider_service.get_encrypted_token(configs) + current_providers = provider_name_to_provider_dict[model_provider_name] + for provider in current_providers: + if provider.provider_type == ProviderType.SYSTEM.value: + quota_type = provider.quota_type + key = f'{ProviderType.SYSTEM.value}:{quota_type}' - @staticmethod - def create_system_provider(tenant: Tenant, provider_name: str = ProviderName.OPENAI.value, quota_limit: int = 200, - is_valid: bool = True): - if current_app.config['EDITION'] != 'CLOUD': - return + if key in provider_parameter_dict: + provider_parameter_dict[key]['is_valid'] = provider.is_valid + provider_parameter_dict[key]['quota_used'] = provider.quota_used + provider_parameter_dict[key]['quota_limit'] = provider.quota_limit + provider_parameter_dict[key]['last_used'] = provider.last_used + elif provider.provider_type == ProviderType.CUSTOM.value \ + and ProviderType.CUSTOM.value in provider_parameter_dict: + # if custom + key = ProviderType.CUSTOM.value + provider_parameter_dict[key]['last_used'] = provider.last_used + provider_parameter_dict[key]['is_valid'] = provider.is_valid - provider = db.session.query(Provider).filter( - Provider.tenant_id == tenant.id, + if model_provider_rule['model_flexibility'] == 'fixed': + provider_parameter_dict[key]['config'] = model_provider_class(provider=provider) \ + .get_provider_credentials(obfuscated=True) + else: + models = [] + provider_models = provider_name_to_provider_model_dict[model_provider_name] + for 
provider_model in provider_models: + models.append({ + "model_name": provider_model.model_name, + "model_type": provider_model.model_type, + "config": model_provider_class(provider=provider) \ + .get_model_credentials(provider_model.model_name, + ModelType.value_of(provider_model.model_type), + obfuscated=True), + "is_valid": provider_model.is_valid + }) + provider_parameter_dict[key]['models'] = models + + provider_config_dict['providers'] = list(provider_parameter_dict.values()) + providers_list[model_provider_name] = provider_config_dict + + return providers_list + + def custom_provider_config_validate(self, provider_name: str, config: dict) -> None: + """ + validate custom provider config. + + :param provider_name: + :param config: + :return: + :raises CredentialsValidateFailedError: When the config credential verification fails. + """ + # get model provider rules + model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name) + + if model_provider_rules['model_flexibility'] != 'fixed': + raise ValueError('Only support fixed model provider') + + # only support provider type CUSTOM + if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']: + raise ValueError('Only support provider type CUSTOM') + + # validate provider config + model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) + model_provider_class.is_provider_credentials_valid_or_raise(config) + + def save_custom_provider_config(self, tenant_id: str, provider_name: str, config: dict) -> None: + """ + save custom provider config. + + :param tenant_id: + :param provider_name: + :param config: + :return: + """ + # validate custom provider config + self.custom_provider_config_validate(provider_name, config) + + # get provider + provider = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, Provider.provider_name == provider_name, - Provider.provider_type == ProviderType.SYSTEM.value - ).one_or_none() + Provider.provider_type == ProviderType.CUSTOM.value + ).first() - if not provider: + model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) + encrypted_config = model_provider_class.encrypt_provider_credentials(tenant_id, config) + + # save provider + if provider: + provider.encrypted_config = json.dumps(encrypted_config) + provider.is_valid = True + provider.updated_at = datetime.datetime.utcnow() + db.session.commit() + else: provider = Provider( - tenant_id=tenant.id, + tenant_id=tenant_id, provider_name=provider_name, - provider_type=ProviderType.SYSTEM.value, - quota_type=ProviderQuotaType.TRIAL.value, - quota_limit=quota_limit, - encrypted_config='', - is_valid=is_valid, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_config), + is_valid=True ) db.session.add(provider) db.session.commit() + + def delete_custom_provider(self, tenant_id: str, provider_name: str) -> None: + """ + delete custom provider. 
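+        also switches the tenant's preferred provider type back to SYSTEM where supported.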
+ + :param tenant_id: + :param provider_name: + :return: + """ + # get provider + provider = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.CUSTOM.value + ).first() + + if provider: + try: + self.switch_preferred_provider(tenant_id, provider_name, ProviderType.SYSTEM.value) + except ValueError: + pass + + db.session.delete(provider) + db.session.commit() + + def custom_provider_model_config_validate(self, + provider_name: str, + model_name: str, + model_type: str, + config: dict) -> None: + """ + validate custom provider model config. + + :param provider_name: + :param model_name: + :param model_type: + :param config: + :return: + :raises CredentialsValidateFailedError: When the config credential verification fails. + """ + # get model provider rules + model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name) + + if model_provider_rules['model_flexibility'] != 'configurable': + raise ValueError('Only support configurable model provider') + + # only support provider type CUSTOM + if ProviderType.CUSTOM.value not in model_provider_rules['support_provider_types']: + raise ValueError('Only support provider type CUSTOM') + + # validate provider model config + model_type = ModelType.value_of(model_type) + model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) + model_provider_class.is_model_credentials_valid_or_raise(model_name, model_type, config) + + def add_or_save_custom_provider_model_config(self, + tenant_id: str, + provider_name: str, + model_name: str, + model_type: str, + config: dict) -> None: + """ + Add or save custom provider model config. + + :param tenant_id: + :param provider_name: + :param model_name: + :param model_type: + :param config: + :return: + """ + # validate custom provider model config + self.custom_provider_model_config_validate(provider_name, model_name, model_type, config) + + # get provider + provider = db.session.query(Provider) \ + .filter( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider_name, + Provider.provider_type == ProviderType.CUSTOM.value + ).first() + + if not provider: + provider = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + provider_type=ProviderType.CUSTOM.value, + is_valid=True + ) + db.session.add(provider) + db.session.commit() + elif not provider.is_valid: + provider.is_valid = True + provider.encrypted_config = None + db.session.commit() + + model_provider_class = ModelProviderFactory.get_model_provider_class(provider_name) + encrypted_config = model_provider_class.encrypt_model_credentials( + tenant_id, + model_name, + ModelType.value_of(model_type), + config + ) + + # get provider model + provider_model = db.session.query(ProviderModel) \ + .filter( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider_name, + ProviderModel.model_name == model_name, + ProviderModel.model_type == model_type + ).first() + + if provider_model: + provider_model.encrypted_config = json.dumps(encrypted_config) + provider_model.is_valid = True + db.session.commit() + else: + provider_model = ProviderModel( + tenant_id=tenant_id, + provider_name=provider_name, + model_name=model_name, + model_type=model_type, + encrypted_config=json.dumps(encrypted_config), + is_valid=True + ) + db.session.add(provider_model) + db.session.commit() + + def delete_custom_provider_model(self, + tenant_id: str, + provider_name: str, + model_name: str, + 
model_type: str) -> None: + """ + delete custom provider model. + + :param tenant_id: + :param provider_name: + :param model_name: + :param model_type: + :return: + """ + # get provider model + provider_model = db.session.query(ProviderModel) \ + .filter( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider_name, + ProviderModel.model_name == model_name, + ProviderModel.model_type == model_type + ).first() + + if provider_model: + db.session.delete(provider_model) + db.session.commit() + + def switch_preferred_provider(self, tenant_id: str, provider_name: str, preferred_provider_type: str) -> None: + """ + switch preferred provider. + + :param tenant_id: + :param provider_name: + :param preferred_provider_type: + :return: + """ + provider_type = ProviderType.value_of(preferred_provider_type) + if not provider_type: + raise ValueError(f'Invalid preferred provider type: {preferred_provider_type}') + + model_provider_rules = ModelProviderFactory.get_provider_rule(provider_name) + if preferred_provider_type not in model_provider_rules['support_provider_types']: + raise ValueError(f'Not support provider type: {preferred_provider_type}') + + model_provider = ModelProviderFactory.get_model_provider_class(provider_name) + if not model_provider.is_provider_type_system_supported(): + return + + # get preferred provider + preferred_model_provider = db.session.query(TenantPreferredModelProvider) \ + .filter( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name == provider_name + ).first() + + if preferred_model_provider: + preferred_model_provider.preferred_provider_type = preferred_provider_type + else: + preferred_model_provider = TenantPreferredModelProvider( + tenant_id=tenant_id, + provider_name=provider_name, + preferred_provider_type=preferred_provider_type + ) + db.session.add(preferred_model_provider) + + db.session.commit() + + def get_default_model_of_model_type(self, tenant_id: str, model_type: str) -> Optional[TenantDefaultModel]: + """ + get default model of model type. + + :param tenant_id: + :param model_type: + :return: + """ + return ModelFactory.get_default_model(tenant_id, ModelType.value_of(model_type)) + + def update_default_model_of_model_type(self, + tenant_id: str, + model_type: str, + provider_name: str, + model_name: str) -> TenantDefaultModel: + """ + update default model of model type. + + :param tenant_id: + :param model_type: + :param provider_name: + :param model_name: + :return: + """ + return ModelFactory.update_default_model(tenant_id, ModelType.value_of(model_type), provider_name, model_name) + + def get_valid_model_list(self, tenant_id: str, model_type: str) -> list: + """ + get valid model list. 
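+        models are collected from the preferred provider of each configured model provider.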
+ + :param tenant_id: + :param model_type: + :return: + """ + valid_model_list = [] + + # get model provider rules + model_provider_rules = ModelProviderFactory.get_provider_rules() + for model_provider_name, model_provider_rule in model_provider_rules.items(): + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + if not model_provider: + continue + + model_list = model_provider.get_supported_model_list(ModelType.value_of(model_type)) + provider = model_provider.provider + for model in model_list: + valid_model_dict = { + "model_name": model['id'], + "model_type": model_type, + "model_provider": { + "provider_name": provider.provider_name, + "provider_type": provider.provider_type + }, + 'features': [] + } + + if 'features' in model: + valid_model_dict['features'] = model['features'] + + if provider.provider_type == ProviderType.SYSTEM.value: + valid_model_dict['model_provider']['quota_type'] = provider.quota_type + valid_model_dict['model_provider']['quota_unit'] = model_provider_rule['system_config']['quota_unit'] + valid_model_dict['model_provider']['quota_limit'] = provider.quota_limit + valid_model_dict['model_provider']['quota_used'] = provider.quota_used + + valid_model_list.append(valid_model_dict) + + return valid_model_list + + def get_model_parameter_rules(self, tenant_id: str, model_provider_name: str, model_name: str, model_type: str) \ + -> ModelKwargsRules: + """ + get model parameter rules. + It depends on preferred provider in use. + + :param tenant_id: + :param model_provider_name: + :param model_name: + :param model_type: + :return: + """ + # get model provider + model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name) + if not model_provider: + # get empty model provider + return ModelKwargsRules() + + # get model parameter rules + return model_provider.get_model_parameter_rules(model_name, ModelType.value_of(model_type)) + diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index abd1f7f3fb..96319818c0 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -1,6 +1,6 @@ from extensions.ext_database import db from models.account import Tenant -from models.provider import Provider, ProviderType, ProviderName +from models.provider import Provider class WorkspaceService: @@ -13,8 +13,8 @@ class WorkspaceService: 'status': tenant.status, 'created_at': tenant.created_at, 'providers': [], - 'in_trail': False, - 'trial_end_reason': 'using_custom' + 'in_trial': True, + 'trial_end_reason': None } # Get providers @@ -25,25 +25,4 @@ class WorkspaceService: # Add providers to the tenant info tenant_info['providers'] = providers - custom_provider = None - system_provider = None - - for provider in providers: - if provider.provider_type == ProviderType.CUSTOM.value: - if provider.is_valid and provider.encrypted_config: - custom_provider = provider - elif provider.provider_type == ProviderType.SYSTEM.value: - if provider.provider_name == ProviderName.OPENAI.value and provider.is_valid: - system_provider = provider - - if system_provider and not custom_provider: - quota_used = system_provider.quota_used if system_provider.quota_used is not None else 0 - quota_limit = system_provider.quota_limit if system_provider.quota_limit is not None else 0 - - if quota_used >= quota_limit: - tenant_info['trial_end_reason'] = 'trial_exceeded' - else: - tenant_info['in_trail'] = True - tenant_info['trial_end_reason'] = None - return tenant_info 
diff --git a/api/tests/conftest.py b/api/tests/conftest.py deleted file mode 100644 index 48de037846..0000000000 --- a/api/tests/conftest.py +++ /dev/null @@ -1,50 +0,0 @@ -# -*- coding:utf-8 -*- - -import pytest -import flask_migrate - -from app import create_app -from extensions.ext_database import db - - -@pytest.fixture(scope='module') -def test_client(): - # Create a Flask app configured for testing - from config import TestConfig - flask_app = create_app(TestConfig()) - flask_app.config.from_object('config.TestingConfig') - - # Create a test client using the Flask application configured for testing - with flask_app.test_client() as testing_client: - # Establish an application context - with flask_app.app_context(): - yield testing_client # this is where the testing happens! - - -@pytest.fixture(scope='module') -def init_database(test_client): - # Initialize the database - with test_client.application.app_context(): - flask_migrate.upgrade() - - yield # this is where the testing happens! - - # Clean up the database - with test_client.application.app_context(): - flask_migrate.downgrade() - - -@pytest.fixture(scope='module') -def db_session(test_client): - with test_client.application.app_context(): - yield db.session - - -@pytest.fixture(scope='function') -def login_default_user(test_client): - - # todo - - yield # this is where the testing happens! - - test_client.get('/logout', follow_redirects=True) \ No newline at end of file diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example new file mode 100644 index 0000000000..f1ee239415 --- /dev/null +++ b/api/tests/integration_tests/.env.example @@ -0,0 +1,35 @@ +# OpenAI API Key +OPENAI_API_KEY= + +# Azure OpenAI API Base Endpoint & API Key +AZURE_OPENAI_API_BASE= +AZURE_OPENAI_API_KEY= + +# Anthropic API Key +ANTHROPIC_API_KEY= + +# Replicate API Key +REPLICATE_API_TOKEN= + +# Hugging Face API Key +HUGGINGFACE_API_KEY= +HUGGINGFACE_ENDPOINT_URL= + +# Minimax Credentials +MINIMAX_API_KEY= +MINIMAX_GROUP_ID= + +# Spark Credentials +SPARK_APP_ID= +SPARK_API_KEY= +SPARK_API_SECRET= + +# Tongyi Credentials +TONGYI_DASHSCOPE_API_KEY= + +# Wenxin Credentials +WENXIN_API_KEY= +WENXIN_SECRET_KEY= + +# ChatGLM Credentials +CHATGLM_API_BASE= \ No newline at end of file diff --git a/api/tests/integration_tests/__init__.py b/api/tests/integration_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/conftest.py b/api/tests/integration_tests/conftest.py new file mode 100644 index 0000000000..6e3ab4b74b --- /dev/null +++ b/api/tests/integration_tests/conftest.py @@ -0,0 +1,19 @@ +import os + +# Getting the absolute path of the current file's directory +ABS_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Getting the absolute path of the project's root directory +PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) + + +# Loading the .env file if it exists +def _load_env() -> None: + dotenv_path = os.path.join(PROJECT_DIR, "tests", "integration_tests", ".env") + if os.path.exists(dotenv_path): + from dotenv import load_dotenv + + load_dotenv(dotenv_path) + + +_load_env() diff --git a/api/tests/integration_tests/models/__init__.py b/api/tests/integration_tests/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/models/embedding/__init__.py b/api/tests/integration_tests/models/embedding/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/tests/integration_tests/models/embedding/test_azure_openai_embedding.py b/api/tests/integration_tests/models/embedding/test_azure_openai_embedding.py new file mode 100644 index 0000000000..9ea202a6ff --- /dev/null +++ b/api/tests/integration_tests/models/embedding/test_azure_openai_embedding.py @@ -0,0 +1,57 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider +from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='azure_openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_azure_openai_embedding_model(mocker): + model_name = 'text-embedding-ada-002' + valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE'] + valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY'] + openai_provider = AzureOpenAIProvider(provider=get_mock_provider()) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='azure_openai', + model_name=model_name, + model_type=ModelType.EMBEDDINGS.value, + encrypted_config=json.dumps({ + 'openai_api_base': valid_openai_api_base, + 'openai_api_key': valid_openai_api_key, + 'base_model_name': model_name + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return AzureOpenAIEmbedding( + model_provider=openai_provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_openai_api_key): + return encrypted_openai_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embedding(mock_decrypt, mocker): + embedding_model = get_mock_azure_openai_embedding_model(mocker) + rst = embedding_model.client.embed_query('test') + assert isinstance(rst, list) + assert len(rst) == 1536 diff --git a/api/tests/integration_tests/models/embedding/test_minimax_embedding.py b/api/tests/integration_tests/models/embedding/test_minimax_embedding.py new file mode 100644 index 0000000000..feaad6bb15 --- /dev/null +++ b/api/tests/integration_tests/models/embedding/test_minimax_embedding.py @@ -0,0 +1,44 @@ +import json +import os +from unittest.mock import patch + +from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding +from core.model_providers.providers.minimax_provider import MinimaxProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_group_id, valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='minimax', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'minimax_group_id': valid_group_id, + 'minimax_api_key': valid_api_key + }), + is_valid=True, + ) + + +def get_mock_embedding_model(): + model_name = 'embo-01' + valid_api_key = os.environ['MINIMAX_API_KEY'] + valid_group_id = os.environ['MINIMAX_GROUP_ID'] + provider = MinimaxProvider(provider=get_mock_provider(valid_group_id, valid_api_key)) + return MinimaxEmbedding( + model_provider=provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', 
side_effect=decrypt_side_effect) +def test_embedding(mock_decrypt): + embedding_model = get_mock_embedding_model() + rst = embedding_model.client.embed_query('test') + assert isinstance(rst, list) + assert len(rst) == 1536 diff --git a/api/tests/integration_tests/models/embedding/test_openai_embedding.py b/api/tests/integration_tests/models/embedding/test_openai_embedding.py new file mode 100644 index 0000000000..14e6133493 --- /dev/null +++ b/api/tests/integration_tests/models/embedding/test_openai_embedding.py @@ -0,0 +1,40 @@ +import json +import os +from unittest.mock import patch + +from core.model_providers.providers.openai_provider import OpenAIProvider +from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_openai_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}), + is_valid=True, + ) + + +def get_mock_openai_embedding_model(): + model_name = 'text-embedding-ada-002' + valid_openai_api_key = os.environ['OPENAI_API_KEY'] + openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key)) + return OpenAIEmbedding( + model_provider=openai_provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_openai_api_key): + return encrypted_openai_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embedding(mock_decrypt): + embedding_model = get_mock_openai_embedding_model() + rst = embedding_model.client.embed_query('test') + assert isinstance(rst, list) + assert len(rst) == 1536 diff --git a/api/tests/integration_tests/models/embedding/test_replicate_embedding.py b/api/tests/integration_tests/models/embedding/test_replicate_embedding.py new file mode 100644 index 0000000000..16531574cc --- /dev/null +++ b/api/tests/integration_tests/models/embedding/test_replicate_embedding.py @@ -0,0 +1,64 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from core.model_providers.models.embedding.replicate_embedding import ReplicateEmbedding +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.replicate_provider import ReplicateProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='replicate', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_embedding_model(mocker): + model_name = 'replicate/all-mpnet-base-v2' + valid_api_key = os.environ['REPLICATE_API_TOKEN'] + model_provider = ReplicateProvider(provider=get_mock_provider()) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='replicate', + model_name=model_name, + model_type=ModelType.EMBEDDINGS.value, + encrypted_config=json.dumps({ + 'replicate_api_token': valid_api_key, + 'model_version': 'b6b7585c9640cd7a9572c6e129c9549d79c9c31f0d3fdce7baac7c67ca38f305' + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return ReplicateEmbedding( + model_provider=model_provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + 
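+# The all-mpnet-base-v2 model pinned above produces 768-dimensional
+# embeddings, hence the length assertions in the tests below.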
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embed_documents(mock_decrypt, mocker): + embedding_model = get_mock_embedding_model(mocker) + rst = embedding_model.client.embed_documents(['test', 'test1']) + assert isinstance(rst, list) + assert len(rst) == 2 + assert len(rst[0]) == 768 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_embed_query(mock_decrypt, mocker): + embedding_model = get_mock_embedding_model(mocker) + rst = embedding_model.client.embed_query('test') + assert isinstance(rst, list) + assert len(rst) == 768 diff --git a/api/tests/integration_tests/models/llm/__init__.py b/api/tests/integration_tests/models/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/models/llm/test_anthropic_model.py b/api/tests/integration_tests/models/llm/test_anthropic_model.py new file mode 100644 index 0000000000..86cfe9922d --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_anthropic_model.py @@ -0,0 +1,61 @@ +import json +import os +from unittest.mock import patch + +from langchain.schema import ChatGeneration, AIMessage + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs +from core.model_providers.models.llm.anthropic_model import AnthropicModel +from core.model_providers.providers.anthropic_provider import AnthropicProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='anthropic', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'anthropic_api_key': valid_api_key}), + is_valid=True, + ) + + +def get_mock_model(model_name): + model_kwargs = ModelKwargs( + max_tokens=10, + temperature=0 + ) + valid_api_key = os.environ['ANTHROPIC_API_KEY'] + model_provider = AnthropicProvider(provider=get_mock_provider(valid_api_key)) + return AnthropicModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt): + model = get_mock_model('claude-2') + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 6 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + model = get_mock_model('claude-2') + messages = [PromptMessage(content='Human: 1 + 1=? 
\nAssistant: ')] + rst = model.run( + messages, + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == '2' diff --git a/api/tests/integration_tests/models/llm/test_azure_openai_model.py b/api/tests/integration_tests/models/llm/test_azure_openai_model.py new file mode 100644 index 0000000000..112dcd6841 --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_azure_openai_model.py @@ -0,0 +1,86 @@ +import json +import os +from unittest.mock import patch, MagicMock + +import pytest +from langchain.schema import ChatGeneration, AIMessage + +from core.model_providers.models.entity.model_params import ModelKwargs, ModelType +from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='azure_openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_azure_openai_model(model_name, mocker): + model_kwargs = ModelKwargs( + max_tokens=10, + temperature=0 + ) + valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE'] + valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY'] + provider = AzureOpenAIProvider(provider=get_mock_provider()) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='azure_openai', + model_name=model_name, + model_type=ModelType.TEXT_GENERATION.value, + encrypted_config=json.dumps({ + 'openai_api_base': valid_openai_api_base, + 'openai_api_key': valid_openai_api_key, + 'base_model_name': model_name + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return AzureOpenAIModel( + model_provider=provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_openai_api_key): + return encrypted_openai_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt, mocker): + openai_model = get_mock_azure_openai_model('text-davinci-003', mocker) + rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')]) + assert rst == 6 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_chat_get_num_tokens(mock_decrypt, mocker): + openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker) + rst = openai_model.get_num_tokens([ + PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'), + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 22 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt, mocker): + openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker) + messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? 
\nAssistant: ')] + rst = openai_model.run( + messages, + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == 'n' diff --git a/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py new file mode 100644 index 0000000000..d55d6e93fe --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_huggingface_hub_model.py @@ -0,0 +1,124 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from langchain.schema import Generation + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs, ModelType +from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel +from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='huggingface_hub', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_model(model_name, huggingfacehub_api_type, mocker): + model_kwargs = ModelKwargs( + max_tokens=10, + temperature=0.01 + ) + valid_api_key = os.environ['HUGGINGFACE_API_KEY'] + endpoint_url = os.environ['HUGGINGFACE_ENDPOINT_URL'] + model_provider = HuggingfaceHubProvider(provider=get_mock_provider()) + + credentials = { + 'huggingfacehub_api_type': huggingfacehub_api_type, + 'huggingfacehub_api_token': valid_api_key + } + + if huggingfacehub_api_type == 'inference_endpoints': + credentials['huggingfacehub_endpoint_url'] = endpoint_url + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='huggingface_hub', + model_name=model_name, + model_type=ModelType.TEXT_GENERATION.value, + encrypted_config=json.dumps(credentials), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return HuggingfaceHubModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + +@patch('huggingface_hub.hf_api.ModelInfo') +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_hosted_inference_api_get_num_tokens(mock_decrypt, mock_model_info, mocker): + mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation') + mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc") + + model = get_mock_model( + 'tiiuae/falcon-40b', + 'hosted_inference_api', + mocker + ) + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 5 + + +@patch('huggingface_hub.hf_api.ModelInfo') +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_inference_endpoints_get_num_tokens(mock_decrypt, mock_model_info, mocker): + mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation') + mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc") + + model = get_mock_model( + '', + 'inference_endpoints', + mocker + ) + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 5 + + +@patch('core.helper.encrypter.decrypt_token', 
side_effect=decrypt_side_effect) +def test_hosted_inference_api_run(mock_decrypt, mocker): + model = get_mock_model( + 'google/flan-t5-base', + 'hosted_inference_api', + mocker + ) + + rst = model.run( + [PromptMessage(content='Human: Are you Really Human? you MUST only answer `y` or `n`? \nAssistant: ')], + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == 'n' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_inference_endpoints_run(mock_decrypt, mocker): + model = get_mock_model( + '', + 'inference_endpoints', + mocker + ) + + rst = model.run( + [PromptMessage(content='Answer the following yes/no question. Can you write a whole Haiku in a single tweet?')], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == 'no' diff --git a/api/tests/integration_tests/models/llm/test_minimax_model.py b/api/tests/integration_tests/models/llm/test_minimax_model.py new file mode 100644 index 0000000000..79a05bc279 --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_minimax_model.py @@ -0,0 +1,64 @@ +import json +import os +from unittest.mock import patch + +from langchain.schema import ChatGeneration, AIMessage, Generation + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs +from core.model_providers.models.llm.minimax_model import MinimaxModel +from core.model_providers.providers.minimax_provider import MinimaxProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_group_id, valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='minimax', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'minimax_group_id': valid_group_id, + 'minimax_api_key': valid_api_key + }), + is_valid=True, + ) + + +def get_mock_model(model_name): + model_kwargs = ModelKwargs( + max_tokens=10, + temperature=0.01 + ) + valid_api_key = os.environ['MINIMAX_API_KEY'] + valid_group_id = os.environ['MINIMAX_GROUP_ID'] + model_provider = MinimaxProvider(provider=get_mock_provider(valid_group_id, valid_api_key)) + return MinimaxModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt): + model = get_mock_model('abab5.5-chat') + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 5 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + model = get_mock_model('abab5.5-chat') + rst = model.run( + [PromptMessage(content='Human: Are you a real Human? you MUST only answer `y` or `n`? 
\nAssistant: ')],
+        stop=['\nHuman:'],
+    )
+    assert len(rst.content) > 0
+    assert rst.content.strip() == 'n'
diff --git a/api/tests/integration_tests/models/llm/test_openai_model.py b/api/tests/integration_tests/models/llm/test_openai_model.py
new file mode 100644
index 0000000000..ebc40fd529
--- /dev/null
+++ b/api/tests/integration_tests/models/llm/test_openai_model.py
@@ -0,0 +1,77 @@
+import json
+import os
+from unittest.mock import patch
+
+from core.model_providers.providers.openai_provider import OpenAIProvider
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.openai_model import OpenAIModel
+from models.provider import Provider, ProviderType
+
+
+def get_mock_provider(valid_openai_api_key):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='openai',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}),
+        is_valid=True,
+    )
+
+
+def get_mock_openai_model(model_name):
+    model_kwargs = ModelKwargs(
+        max_tokens=10,
+        temperature=0
+    )
+    valid_openai_api_key = os.environ['OPENAI_API_KEY']
+    openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key))
+    return OpenAIModel(
+        model_provider=openai_provider,
+        name=model_name,
+        model_kwargs=model_kwargs
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
+    return encrypted_openai_api_key
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_num_tokens(mock_decrypt):
+    openai_model = get_mock_openai_model('text-davinci-003')
+    rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
+    assert rst == 6
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_get_num_tokens(mock_decrypt):
+    openai_model = get_mock_openai_model('gpt-3.5-turbo')
+    rst = openai_model.get_num_tokens([
+        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst == 22
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt):
+    openai_model = get_mock_openai_model('text-davinci-003')
+    rst = openai_model.run(
+        [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')],
+        stop=['\nHuman:'],
+    )
+    assert len(rst.content) > 0
+    assert rst.content.strip() == 'n'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_chat_run(mock_decrypt):
+    openai_model = get_mock_openai_model('gpt-3.5-turbo')
+    messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? 
\nAssistant: ')] + rst = openai_model.run( + messages, + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == 'n' diff --git a/api/tests/integration_tests/models/llm/test_replicate_model.py b/api/tests/integration_tests/models/llm/test_replicate_model.py new file mode 100644 index 0000000000..7689a3c0fc --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_replicate_model.py @@ -0,0 +1,73 @@ +import json +import os +from unittest.mock import patch, MagicMock + +from langchain.schema import Generation + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs, ModelType +from core.model_providers.models.llm.replicate_model import ReplicateModel +from core.model_providers.providers.replicate_provider import ReplicateProvider +from models.provider import Provider, ProviderType, ProviderModel + + +def get_mock_provider(): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='replicate', + provider_type=ProviderType.CUSTOM.value, + encrypted_config='', + is_valid=True, + ) + + +def get_mock_model(model_name, model_version, mocker): + model_kwargs = ModelKwargs( + max_tokens=10, + temperature=0.01 + ) + valid_api_key = os.environ['REPLICATE_API_TOKEN'] + model_provider = ReplicateProvider(provider=get_mock_provider()) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + provider_name='replicate', + model_name=model_name, + model_type=ModelType.TEXT_GENERATION.value, + encrypted_config=json.dumps({ + 'replicate_api_token': valid_api_key, + 'model_version': model_version + }), + is_valid=True, + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + return ReplicateModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt, mocker): + model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker) + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 7 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt, mocker): + model = get_mock_model('a16z-infra/llama-2-13b-chat', '2a7f981751ec7fdf87b5b91ad4db53683a98082e9ff7bfd12c8cd5ea85980a52', mocker) + messages = [PromptMessage(content='Human: 1+1=? 
\nAnswer: ')]
+    rst = model.run(
+        messages
+    )
+    assert len(rst.content) > 0
diff --git a/api/tests/integration_tests/models/llm/test_spark_model.py b/api/tests/integration_tests/models/llm/test_spark_model.py
new file mode 100644
index 0000000000..4e62aeb2cd
--- /dev/null
+++ b/api/tests/integration_tests/models/llm/test_spark_model.py
@@ -0,0 +1,67 @@
+import json
+import os
+from unittest.mock import patch
+
+from langchain.schema import ChatGeneration, AIMessage, Generation
+
+from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.model_params import ModelKwargs
+from core.model_providers.models.llm.spark_model import SparkModel
+from core.model_providers.providers.spark_provider import SparkProvider
+from models.provider import Provider, ProviderType
+
+
+def get_mock_provider(valid_app_id, valid_api_key, valid_api_secret):
+    return Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name='spark',
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({
+            'app_id': valid_app_id,
+            'api_key': valid_api_key,
+            'api_secret': valid_api_secret,
+        }),
+        is_valid=True,
+    )
+
+
+def get_mock_model(model_name):
+    model_kwargs = ModelKwargs(
+        max_tokens=10,
+        temperature=0.01
+    )
+    valid_app_id = os.environ['SPARK_APP_ID']
+    valid_api_key = os.environ['SPARK_API_KEY']
+    valid_api_secret = os.environ['SPARK_API_SECRET']
+    model_provider = SparkProvider(provider=get_mock_provider(valid_app_id, valid_api_key, valid_api_secret))
+    return SparkModel(
+        model_provider=model_provider,
+        name=model_name,
+        model_kwargs=model_kwargs
+    )
+
+
+def decrypt_side_effect(tenant_id, encrypted_api_key):
+    return encrypted_api_key
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_num_tokens(mock_decrypt):
+    model = get_mock_model('spark')
+    rst = model.get_num_tokens([
+        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
+    ])
+    assert rst == 6
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_run(mock_decrypt):
+    model = get_mock_model('spark')
+    messages = [PromptMessage(content='Human: 1 + 1=? 
\nAssistant: Integer answer is:')] + rst = model.run( + messages, + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == '2' diff --git a/api/tests/integration_tests/models/llm/test_tongyi_model.py b/api/tests/integration_tests/models/llm/test_tongyi_model.py new file mode 100644 index 0000000000..2f9e33992f --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_tongyi_model.py @@ -0,0 +1,61 @@ +import json +import os +from unittest.mock import patch + +from langchain.schema import ChatGeneration, AIMessage, Generation + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs +from core.model_providers.models.llm.tongyi_model import TongyiModel +from core.model_providers.providers.tongyi_provider import TongyiProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='tongyi', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'dashscope_api_key': valid_api_key, + }), + is_valid=True, + ) + + +def get_mock_model(model_name): + model_kwargs = ModelKwargs( + max_tokens=10, + temperature=0.01 + ) + valid_api_key = os.environ['TONGYI_DASHSCOPE_API_KEY'] + model_provider = TongyiProvider(provider=get_mock_provider(valid_api_key)) + return TongyiModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt): + model = get_mock_model('qwen-v1') + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 5 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + model = get_mock_model('qwen-v1') + rst = model.run( + [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? 
\nAssistant: ')], + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 diff --git a/api/tests/integration_tests/models/llm/test_wenxin_model.py b/api/tests/integration_tests/models/llm/test_wenxin_model.py new file mode 100644 index 0000000000..f517d05c25 --- /dev/null +++ b/api/tests/integration_tests/models/llm/test_wenxin_model.py @@ -0,0 +1,63 @@ +import json +import os +from unittest.mock import patch + + +from core.model_providers.models.entity.message import PromptMessage, MessageType +from core.model_providers.models.entity.model_params import ModelKwargs +from core.model_providers.models.llm.wenxin_model import WenxinModel +from core.model_providers.providers.wenxin_provider import WenxinProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_api_key, valid_secret_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='wenxin', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({ + 'api_key': valid_api_key, + 'secret_key': valid_secret_key, + }), + is_valid=True, + ) + + +def get_mock_model(model_name): + model_kwargs = ModelKwargs( + temperature=0.01 + ) + valid_api_key = os.environ['WENXIN_API_KEY'] + valid_secret_key = os.environ['WENXIN_SECRET_KEY'] + model_provider = WenxinProvider(provider=get_mock_provider(valid_api_key, valid_secret_key)) + return WenxinModel( + model_provider=model_provider, + name=model_name, + model_kwargs=model_kwargs + ) + + +def decrypt_side_effect(tenant_id, encrypted_api_key): + return encrypted_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_num_tokens(mock_decrypt): + model = get_mock_model('ernie-bot') + rst = model.get_num_tokens([ + PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?') + ]) + assert rst == 5 + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + model = get_mock_model('ernie-bot') + messages = [PromptMessage(content='Human: 1 + 1=? 
\nAssistant: Integer answer is:')] + rst = model.run( + messages, + stop=['\nHuman:'], + ) + assert len(rst.content) > 0 + assert rst.content.strip() == '2' diff --git a/api/tests/integration_tests/models/moderation/__init__.py b/api/tests/integration_tests/models/moderation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/models/moderation/test_openai_moderation.py b/api/tests/integration_tests/models/moderation/test_openai_moderation.py new file mode 100644 index 0000000000..c27f43e141 --- /dev/null +++ b/api/tests/integration_tests/models/moderation/test_openai_moderation.py @@ -0,0 +1,40 @@ +import json +import os +from unittest.mock import patch + +from core.model_providers.models.moderation.openai_moderation import OpenAIModeration, DEFAULT_AUDIO_MODEL +from core.model_providers.providers.openai_provider import OpenAIProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_openai_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}), + is_valid=True, + ) + + +def get_mock_openai_moderation_model(): + valid_openai_api_key = os.environ['OPENAI_API_KEY'] + openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key)) + return OpenAIModeration( + model_provider=openai_provider, + name=DEFAULT_AUDIO_MODEL + ) + + +def decrypt_side_effect(tenant_id, encrypted_openai_api_key): + return encrypted_openai_api_key + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + model = get_mock_openai_moderation_model() + rst = model.run('hello') + + assert isinstance(rst, dict) + assert 'id' in rst diff --git a/api/tests/integration_tests/models/speech2text/__init__.py b/api/tests/integration_tests/models/speech2text/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/models/speech2text/audio.mp3 b/api/tests/integration_tests/models/speech2text/audio.mp3 new file mode 100644 index 0000000000..7c86e02e16 Binary files /dev/null and b/api/tests/integration_tests/models/speech2text/audio.mp3 differ diff --git a/api/tests/integration_tests/models/speech2text/test_openai_whisper.py b/api/tests/integration_tests/models/speech2text/test_openai_whisper.py new file mode 100644 index 0000000000..a649c794e7 --- /dev/null +++ b/api/tests/integration_tests/models/speech2text/test_openai_whisper.py @@ -0,0 +1,50 @@ +import json +import os +from unittest.mock import patch + +from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper +from core.model_providers.providers.openai_provider import OpenAIProvider +from models.provider import Provider, ProviderType + + +def get_mock_provider(valid_openai_api_key): + return Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name='openai', + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({'openai_api_key': valid_openai_api_key}), + is_valid=True, + ) + + +def get_mock_openai_whisper_model(): + model_name = 'whisper-1' + valid_openai_api_key = os.environ['OPENAI_API_KEY'] + openai_provider = OpenAIProvider(provider=get_mock_provider(valid_openai_api_key)) + return OpenAIWhisper( + model_provider=openai_provider, + name=model_name + ) + + +def decrypt_side_effect(tenant_id, encrypted_openai_api_key): + return encrypted_openai_api_key + + 
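+# decrypt_token is patched with decrypt_side_effect in the test below, so the
+# stored credential is returned unchanged and no real encryption key pair is
+# needed to construct the provider.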
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_run(mock_decrypt): + # Get the directory of the current file + current_dir = os.path.dirname(os.path.abspath(__file__)) + + # Construct the path to the audio file + audio_file_path = os.path.join(current_dir, 'audio.mp3') + + model = get_mock_openai_whisper_model() + + # Open the file and get the file object + with open(audio_file_path, 'rb') as audio_file: + rst = model.run(audio_file) + + assert isinstance(rst, dict) + assert rst['text'] == '1, 2, 3, 4, 5, 6, 7, 8, 9, 10' diff --git a/api/tests/test_controllers/test_account_api.py.bak b/api/tests/test_controllers/test_account_api.py.bak deleted file mode 100644 index a73c796b78..0000000000 --- a/api/tests/test_controllers/test_account_api.py.bak +++ /dev/null @@ -1,75 +0,0 @@ -import json -import pytest -from flask import url_for - -from models.model import Account - -# Sample user data for testing -sample_user_data = { - 'name': 'Test User', - 'email': 'test@example.com', - 'interface_language': 'en-US', - 'interface_theme': 'light', - 'timezone': 'America/New_York', - 'password': 'testpassword', - 'new_password': 'newtestpassword', - 'repeat_new_password': 'newtestpassword' -} - -# Create a test user and log them in -@pytest.fixture(scope='function') -def logged_in_user(client, session): - # Create test user and add them to the database - # Replace this with your actual User model and any required fields - - # todo refer to api.controllers.setup.SetupApi.post() to create a user - db_user_data = sample_user_data.copy() - db_user_data['password_salt'] = 'testpasswordsalt' - del db_user_data['new_password'] - del db_user_data['repeat_new_password'] - test_user = Account(**db_user_data) - session.add(test_user) - session.commit() - - # Log in the test user - client.post(url_for('console.loginapi'), data={'email': sample_user_data['email'], 'password': sample_user_data['password']}) - - return test_user - -def test_account_profile(logged_in_user, client): - response = client.get(url_for('console.accountprofileapi')) - assert response.status_code == 200 - assert json.loads(response.data)['name'] == sample_user_data['name'] - -def test_account_name(logged_in_user, client): - new_name = 'New Test User' - response = client.post(url_for('console.accountnameapi'), json={'name': new_name}) - assert response.status_code == 200 - assert json.loads(response.data)['name'] == new_name - -def test_account_interface_language(logged_in_user, client): - new_language = 'zh-CN' - response = client.post(url_for('console.accountinterfacelanguageapi'), json={'interface_language': new_language}) - assert response.status_code == 200 - assert json.loads(response.data)['interface_language'] == new_language - -def test_account_interface_theme(logged_in_user, client): - new_theme = 'dark' - response = client.post(url_for('console.accountinterfacethemeapi'), json={'interface_theme': new_theme}) - assert response.status_code == 200 - assert json.loads(response.data)['interface_theme'] == new_theme - -def test_account_timezone(logged_in_user, client): - new_timezone = 'Asia/Shanghai' - response = client.post(url_for('console.accounttimezoneapi'), json={'timezone': new_timezone}) - assert response.status_code == 200 - assert json.loads(response.data)['timezone'] == new_timezone - -def test_account_password(logged_in_user, client): - response = client.post(url_for('console.accountpasswordapi'), json={ - 'password': sample_user_data['password'], - 'new_password': 
sample_user_data['new_password'], - 'repeat_new_password': sample_user_data['repeat_new_password'] - }) - assert response.status_code == 200 - assert json.loads(response.data)['result'] == 'success' diff --git a/api/tests/test_controllers/test_login.py b/api/tests/test_controllers/test_login.py deleted file mode 100644 index 559e2f809e..0000000000 --- a/api/tests/test_controllers/test_login.py +++ /dev/null @@ -1,108 +0,0 @@ -import pytest -from app import create_app, db -from flask_login import current_user -from models.model import Account, TenantAccountJoin, Tenant - - -@pytest.fixture -def client(test_client, db_session): - app = create_app() - app.config["TESTING"] = True - with app.app_context(): - db.create_all() - yield test_client - db.drop_all() - - -def test_login_api_post(client, db_session): - # create a tenant, account, and tenant account join - tenant = Tenant(name="Test Tenant", status="normal") - account = Account(email="test@test.com", name="Test User") - account.password_salt = "uQ7K0/0wUJ7VPhf3qBzwNQ==" - account.password = "A9YpfzjK7c/tOwzamrvpJg==" - db.session.add_all([tenant, account]) - db.session.flush() - tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True) - db.session.add(tenant_account_join) - db.session.commit() - - # login with correct credentials - response = client.post("/login", json={ - "email": "test@test.com", - "password": "Abc123456", - "remember_me": True - }) - assert response.status_code == 200 - assert response.json == {"result": "success"} - assert current_user == account - assert 'tenant_id' in client.session - assert client.session['tenant_id'] == tenant.id - - # login with incorrect password - response = client.post("/login", json={ - "email": "test@test.com", - "password": "wrong_password", - "remember_me": True - }) - assert response.status_code == 401 - - # login with non-existent account - response = client.post("/login", json={ - "email": "non_existent_account@test.com", - "password": "Abc123456", - "remember_me": True - }) - assert response.status_code == 401 - - -def test_logout_api_get(client, db_session): - # create a tenant, account, and tenant account join - tenant = Tenant(name="Test Tenant", status="normal") - account = Account(email="test@test.com", name="Test User") - db.session.add_all([tenant, account]) - db.session.flush() - tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, is_tenant_owner=True) - db.session.add(tenant_account_join) - db.session.commit() - - # login and check if session variable and current_user are set - with client.session_transaction() as session: - session['tenant_id'] = tenant.id - client.post("/login", json={ - "email": "test@test.com", - "password": "Abc123456", - "remember_me": True - }) - assert current_user == account - assert 'tenant_id' in client.session - assert client.session['tenant_id'] == tenant.id - - # logout and check if session variable and current_user are unset - response = client.get("/logout") - assert response.status_code == 200 - assert current_user.is_authenticated is False - assert 'tenant_id' not in client.session - - -def test_reset_password_api_get(client, db_session): - # create a tenant, account, and tenant account join - tenant = Tenant(name="Test Tenant", status="normal") - account = Account(email="test@test.com", name="Test User") - db.session.add_all([tenant, account]) - db.session.flush() - tenant_account_join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, 
is_tenant_owner=True) - db.session.add(tenant_account_join) - db.session.commit() - - # reset password in cloud edition - app = client.application - app.config["CLOUD_EDITION"] = True - response = client.get("/reset_password") - assert response.status_code == 200 - assert response.json == {"result": "success"} - - # reset password in non-cloud edition - app.config["CLOUD_EDITION"] = False - response = client.get("/reset_password") - assert response.status_code == 200 - assert response.json == {"result": "success"} diff --git a/api/tests/test_controllers/test_setup.py b/api/tests/test_controllers/test_setup.py deleted file mode 100644 index 96a9b0911e..0000000000 --- a/api/tests/test_controllers/test_setup.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -import pytest -from models.model import Account, Tenant, TenantAccountJoin - - -def test_setup_api_get(test_client,db_session): - response = test_client.get("/setup") - assert response.status_code == 200 - assert response.json == {"step": "not_start"} - - # create a tenant and check again - tenant = Tenant(name="Test Tenant", status="normal") - db_session.add(tenant) - db_session.commit() - response = test_client.get("/setup") - assert response.status_code == 200 - assert response.json == {"step": "step2"} - - # create setup file and check again - response = test_client.get("/setup") - assert response.status_code == 200 - assert response.json == {"step": "finished"} - - -def test_setup_api_post(test_client): - response = test_client.post("/setup", json={ - "email": "test@test.com", - "name": "Test User", - "password": "Abc123456" - }) - assert response.status_code == 200 - assert response.json == {"result": "success", "next_step": "step2"} - - # check if the tenant, account, and tenant account join records were created - tenant = Tenant.query.first() - assert tenant.name == "Test User's LLM Factory" - assert tenant.status == "normal" - assert tenant.encrypt_public_key - - account = Account.query.first() - assert account.email == "test@test.com" - assert account.name == "Test User" - assert account.password_salt - assert account.password - assert TenantAccountJoin.query.filter_by(account_id=account.id, is_tenant_owner=True).count() == 1 - - # check if password is encrypted correctly - salt = account.password_salt.encode() - password_hashed = account.password.encode() - assert account.password == base64.b64encode(hash_password("Abc123456", salt)).decode() - - -def test_setup_step2_api_post(test_client,db_session): - # create a tenant, account, and setup file - tenant = Tenant(name="Test Tenant", status="normal") - account = Account(email="test@test.com", name="Test User") - db_session.add_all([tenant, account]) - db_session.commit() - - # try to set up with incorrect language - response = test_client.post("/setup/step2", json={ - "interface_language": "invalid_language", - "timezone": "Asia/Shanghai" - }) - assert response.status_code == 400 - - # set up successfully - response = test_client.post("/setup/step2", json={ - "interface_language": "en", - "timezone": "Asia/Shanghai" - }) - assert response.status_code == 200 - assert response.json == {"result": "success", "next_step": "finished"} - - # check if account was updated correctly - account = Account.query.first() - assert account.interface_language == "en" - assert account.timezone == "Asia/Shanghai" - assert account.interface_theme == "light" - assert account.last_login_ip == "127.0.0.1" diff --git a/api/tests/test_factory.py b/api/tests/test_factory.py deleted file mode 100644 index 
0d73168b43..0000000000 --- a/api/tests/test_factory.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding:utf-8 -*- - -import pytest - -from app import create_app - -def test_create_app(): - - # Test Default(CE) Config - app = create_app() - - assert app.config['SECRET_KEY'] is not None - assert app.config['SQLALCHEMY_DATABASE_URI'] is not None - assert app.config['EDITION'] == "SELF_HOSTED" - - # Test TestConfig - from config import TestConfig - test_app = create_app(TestConfig()) - - assert test_app.config['SECRET_KEY'] is not None - assert test_app.config['SQLALCHEMY_DATABASE_URI'] is not None - assert test_app.config['TESTING'] is True \ No newline at end of file diff --git a/api/tests/unit_tests/__init__.py b/api/tests/unit_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/model_providers/__init__.py b/api/tests/unit_tests/model_providers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/model_providers/fake_model_provider.py b/api/tests/unit_tests/model_providers/fake_model_provider.py new file mode 100644 index 0000000000..4e14d5924e --- /dev/null +++ b/api/tests/unit_tests/model_providers/fake_model_provider.py @@ -0,0 +1,44 @@ +from typing import Type + +from core.model_providers.models.base import BaseProviderModel +from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules +from core.model_providers.models.llm.openai_model import OpenAIModel +from core.model_providers.providers.base import BaseModelProvider + + +class FakeModelProvider(BaseModelProvider): + @property + def provider_name(self): + return 'fake' + + def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]: + return [{'id': 'test_model', 'name': 'Test Model'}] + + def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]: + return OpenAIModel + + @classmethod + def is_provider_credentials_valid_or_raise(cls, credentials: dict): + pass + + @classmethod + def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict: + return credentials + + def get_provider_credentials(self, obfuscated: bool = False) -> dict: + return {} + + @classmethod + def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict): + pass + + @classmethod + def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType, + credentials: dict) -> dict: + return credentials + + def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules: + return ModelKwargsRules() + + def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict: + return {} diff --git a/api/tests/unit_tests/model_providers/test_anthropic_provider.py b/api/tests/unit_tests/model_providers/test_anthropic_provider.py new file mode 100644 index 0000000000..ea4b62a20a --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_anthropic_provider.py @@ -0,0 +1,123 @@ +from typing import List, Optional, Any + +import anthropic +import httpx +import pytest +from unittest.mock import patch +import json + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage + +from core.model_providers.providers.anthropic_provider import AnthropicProvider +from core.model_providers.providers.base import CredentialsValidateFailedError +from models.provider import ProviderType, Provider + + 
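+# These unit tests stub out ChatAnthropic._generate and the encrypter helpers,
+# so AnthropicProvider's credential validation, encryption and retrieval logic
+# is exercised without network access or a real tenant encryption key.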
+PROVIDER_NAME = 'anthropic' +MODEL_PROVIDER_CLASS = AnthropicProvider +VALIDATE_CREDENTIAL_KEY = 'anthropic_api_key' + + +def mock_chat_generate(messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any): + return ChatResult(generations=[ChatGeneration(message=AIMessage(content='answer'))]) + + +def mock_chat_generate_invalid(messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any): + raise anthropic.APIStatusError('Invalid credentials', + request=httpx._models.Request( + method='POST', + url='https://api.anthropic.com/v1/completions', + ), + response=httpx._models.Response( + status_code=401, + ), + body=None + ) + + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate) +def test_is_provider_credentials_valid_or_raise_valid(mock_create): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'valid_key'}) + + +@patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate_invalid) +def test_is_provider_credentials_valid_or_raise_invalid(mock_create): + # raise CredentialsValidateFailedError if anthropic_api_key is not in credentials + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) + + # raise CredentialsValidateFailedError if anthropic_api_key is invalid + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'}) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_credentials(mock_encrypt): + api_key = 'valid_key' + result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key}) + mock_encrypt.assert_called_with('tenant_id', api_key) + assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_custom(mock_decrypt): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials() + assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_obfuscated(mock_decrypt): + api_key = 'valid_key' + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{api_key}'}), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials(obfuscated=True) + middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2] + assert len(middle_token) == max(len(api_key) - 8, 0) + assert all(char == '*' for char in middle_token) + + 
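+# A rough sketch of the obfuscation the assertions above imply (assuming the
+# helper keeps the first 6 and last 2 characters visible; the exact format is
+# defined in the obfuscation helper, not in this test):
+#   'valid_key' -> 'valid_' + '*' * (len('valid_key') - 8) + 'ey'
+# For SYSTEM (hosted) providers, credentials come from the shared
+# hosted_model_providers config rather than the tenant's encrypted_config,
+# which is what the patched hosted_model_providers.anthropic supplies below.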
+@patch('core.model_providers.providers.hosted.hosted_model_providers.anthropic') +def test_get_credentials_hosted(mock_hosted): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.SYSTEM.value, + encrypted_config='', + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + mock_hosted.api_key = 'hosted_key' + result = model_provider.get_provider_credentials() + assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key' diff --git a/api/tests/unit_tests/model_providers/test_azure_openai_provider.py b/api/tests/unit_tests/model_providers/test_azure_openai_provider.py new file mode 100644 index 0000000000..43788d4e07 --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_azure_openai_provider.py @@ -0,0 +1,117 @@ +import pytest +from unittest.mock import patch, MagicMock +import json + +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider +from core.model_providers.providers.base import CredentialsValidateFailedError +from models.provider import ProviderType, Provider, ProviderModel + +PROVIDER_NAME = 'azure_openai' +MODEL_PROVIDER_CLASS = AzureOpenAIProvider +VALIDATE_CREDENTIAL = { + 'openai_api_base': 'https://xxxx.openai.azure.com/', + 'openai_api_key': 'valid_key', + 'base_model_name': 'gpt-35-turbo' +} + + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +def test_is_model_credentials_valid_or_raise(mocker): + mocker.patch('langchain.chat_models.base.BaseChatModel.generate', return_value=None) + + # assert True if credentials is valid + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials=VALIDATE_CREDENTIAL + ) + + +def test_is_model_credentials_valid_or_raise_invalid(): + # raise CredentialsValidateFailedError if credentials is not in credentials + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={} + ) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_model_credentials(mock_encrypt): + openai_api_key = 'valid_key' + result = MODEL_PROVIDER_CLASS.encrypt_model_credentials( + tenant_id='tenant_id', + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={'openai_api_key': openai_api_key} + ) + mock_encrypt.assert_called_with('tenant_id', openai_api_key) + assert result['openai_api_key'] == f'encrypted_{openai_api_key}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_model_credentials_custom(mock_decrypt, mocker): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=None, + is_valid=True, + ) + + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key'] + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + encrypted_config=json.dumps(encrypted_credential) + ) + mocker.patch('extensions.ext_database.db.session.query', 
return_value=mock_query) + + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_model_credentials( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION + ) + assert result['openai_api_key'] == 'valid_key' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_model_credentials_obfuscated(mock_decrypt, mocker): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=None, + is_valid=True, + ) + + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key'] + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + encrypted_config=json.dumps(encrypted_credential) + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_model_credentials( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + obfuscated=True + ) + middle_token = result['openai_api_key'][6:-2] + assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['openai_api_key']) - 8, 0) + assert all(char == '*' for char in middle_token) diff --git a/api/tests/unit_tests/model_providers/test_base_model_provider.py b/api/tests/unit_tests/model_providers/test_base_model_provider.py new file mode 100644 index 0000000000..7d6e56eb0a --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_base_model_provider.py @@ -0,0 +1,72 @@ +from unittest.mock import MagicMock + +import pytest + +from core.model_providers.error import QuotaExceededError +from core.model_providers.models.entity.model_params import ModelType +from models.provider import Provider, ProviderType +from tests.unit_tests.model_providers.fake_model_provider import FakeModelProvider + + +def test_get_supported_model_list(mocker): + mocker.patch.object( + FakeModelProvider, + 'get_rules', + return_value={'support_provider_types': ['custom'], 'model_flexibility': 'configurable'} + ) + + mock_provider_model = MagicMock() + mock_provider_model.model_name = 'test_model' + mock_query = MagicMock() + mock_query.filter.return_value.order_by.return_value.all.return_value = [mock_provider_model] + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + provider = FakeModelProvider(provider=Provider()) + result = provider.get_supported_model_list(ModelType.TEXT_GENERATION) + + assert result == [{'id': 'test_model', 'name': 'test_model'}] + + +def test_check_quota_over_limit(mocker): + mocker.patch.object( + FakeModelProvider, + 'get_rules', + return_value={'support_provider_types': ['system']} + ) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = None + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value)) + + with pytest.raises(QuotaExceededError): + provider.check_quota_over_limit() + + +def test_check_quota_not_over_limit(mocker): + mocker.patch.object( + FakeModelProvider, + 'get_rules', + return_value={'support_provider_types': ['system']} + ) + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = Provider() + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + provider = 
FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
+
+    assert provider.check_quota_over_limit() is None
+
+
+def test_check_custom_quota_over_limit(mocker):
+    mocker.patch.object(
+        FakeModelProvider,
+        'get_rules',
+        return_value={'support_provider_types': ['custom']}
+    )
+
+    provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.CUSTOM.value))
+
+    assert provider.check_quota_over_limit() is None
diff --git a/api/tests/unit_tests/model_providers/test_chatglm_provider.py b/api/tests/unit_tests/model_providers/test_chatglm_provider.py
new file mode 100644
index 0000000000..9dfa1291f4
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_chatglm_provider.py
@@ -0,0 +1,86 @@
+import pytest
+from unittest.mock import patch
+import json
+
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.chatglm_provider import ChatGLMProvider
+from models.provider import ProviderType, Provider
+
+
+PROVIDER_NAME = 'chatglm'
+MODEL_PROVIDER_CLASS = ChatGLMProvider
+VALIDATE_CREDENTIAL = {
+    'api_base': 'valid_api_base',
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_provider_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('langchain.llms.chatglm.ChatGLM._call',
+                 return_value="abc")
+
+    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
+
+
+def test_is_provider_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if api_base is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    credential = VALIDATE_CREDENTIAL.copy()
+    credential['api_base'] = 'invalid_api_base'
+
+    # raise CredentialsValidateFailedError if api_base is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
+    assert result['api_base'] == f'encrypted_{VALIDATE_CREDENTIAL["api_base"]}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result['api_base'] == 'valid_api_base'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
+
+    provider = Provider(
+        id='provider_id',
tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials(obfuscated=True) + middle_token = result['api_base'][6:-2] + assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_base']) - 8, 0) + assert all(char == '*' for char in middle_token) diff --git a/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py new file mode 100644 index 0000000000..3f3384834c --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_huggingface_hub_provider.py @@ -0,0 +1,161 @@ +import pytest +from unittest.mock import patch, MagicMock +import json + +from core.model_providers.models.entity.model_params import ModelType +from core.model_providers.providers.base import CredentialsValidateFailedError +from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider +from models.provider import ProviderType, Provider, ProviderModel + +PROVIDER_NAME = 'huggingface_hub' +MODEL_PROVIDER_CLASS = HuggingfaceHubProvider +HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL = { + 'huggingfacehub_api_type': 'hosted_inference_api', + 'huggingfacehub_api_token': 'valid_key' +} + +INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL = { + 'huggingfacehub_api_type': 'inference_endpoints', + 'huggingfacehub_api_token': 'valid_key', + 'huggingfacehub_endpoint_url': 'valid_url' +} + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +@patch('huggingface_hub.hf_api.ModelInfo') +def test_hosted_inference_api_is_credentials_valid_or_raise_valid(mock_model_info, mocker): + mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation') + mocker.patch('langchain.llms.huggingface_hub.HuggingFaceHub._call', return_value="abc") + + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials=HOSTED_INFERENCE_API_VALIDATE_CREDENTIAL + ) + +@patch('huggingface_hub.hf_api.ModelInfo') +def test_hosted_inference_api_is_credentials_valid_or_raise_invalid(mock_model_info): + mock_model_info.return_value = MagicMock(pipeline_tag='text2text-generation') + + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={} + ) + + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={ + 'huggingfacehub_api_type': 'hosted_inference_api', + }) + + +def test_inference_endpoints_is_credentials_valid_or_raise_valid(mocker): + mocker.patch('huggingface_hub.hf_api.HfApi.whoami', return_value=None) + mocker.patch('langchain.llms.huggingface_endpoint.HuggingFaceEndpoint._call', return_value="abc") + + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL + ) + +def test_inference_endpoints_is_credentials_valid_or_raise_invalid(mocker): + mocker.patch('huggingface_hub.hf_api.HfApi.whoami', 
return_value=None) + + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={} + ) + + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials={ + 'huggingfacehub_api_type': 'inference_endpoints', + 'huggingfacehub_endpoint_url': 'valid_url' + }) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_model_credentials(mock_encrypt): + api_key = 'valid_key' + result = MODEL_PROVIDER_CLASS.encrypt_model_credentials( + tenant_id='tenant_id', + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + credentials=INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy() + ) + mock_encrypt.assert_called_with('tenant_id', api_key) + assert result['huggingfacehub_api_token'] == f'encrypted_{api_key}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_model_credentials_custom(mock_decrypt, mocker): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=None, + is_valid=True, + ) + + encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy() + encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token'] + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + encrypted_config=json.dumps(encrypted_credential) + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_model_credentials( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION + ) + assert result['huggingfacehub_api_token'] == 'valid_key' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_model_credentials_obfuscated(mock_decrypt, mocker): + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=None, + is_valid=True, + ) + + encrypted_credential = INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL.copy() + encrypted_credential['huggingfacehub_api_token'] = 'encrypted_' + encrypted_credential['huggingfacehub_api_token'] + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = ProviderModel( + encrypted_config=json.dumps(encrypted_credential) + ) + mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query) + + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_model_credentials( + model_name='test_model_name', + model_type=ModelType.TEXT_GENERATION, + obfuscated=True + ) + middle_token = result['huggingfacehub_api_token'][6:-2] + assert len(middle_token) == max(len(INFERENCE_ENDPOINTS_VALIDATE_CREDENTIAL['huggingfacehub_api_token']) - 8, 0) + assert all(char == '*' for char in middle_token) diff --git a/api/tests/unit_tests/model_providers/test_minimax_provider.py b/api/tests/unit_tests/model_providers/test_minimax_provider.py new file mode 100644 index 0000000000..ec3e476273 --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_minimax_provider.py @@ -0,0 +1,88 @@ +import pytest +from 
unittest.mock import patch +import json + +from core.model_providers.providers.base import CredentialsValidateFailedError +from core.model_providers.providers.minimax_provider import MinimaxProvider +from models.provider import ProviderType, Provider + + +PROVIDER_NAME = 'minimax' +MODEL_PROVIDER_CLASS = MinimaxProvider +VALIDATE_CREDENTIAL = { + 'minimax_group_id': 'fake-group-id', + 'minimax_api_key': 'valid_key' +} + + +def encrypt_side_effect(tenant_id, encrypt_key): + return f'encrypted_{encrypt_key}' + + +def decrypt_side_effect(tenant_id, encrypted_key): + return encrypted_key.replace('encrypted_', '') + + +def test_is_provider_credentials_valid_or_raise_valid(mocker): + mocker.patch('langchain.llms.minimax.Minimax._call', return_value='abc') + + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL) + + +def test_is_provider_credentials_valid_or_raise_invalid(): + # raise CredentialsValidateFailedError if api_key is not in credentials + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({}) + + credential = VALIDATE_CREDENTIAL.copy() + credential['minimax_api_key'] = 'invalid_key' + + # raise CredentialsValidateFailedError if api_key is invalid + with pytest.raises(CredentialsValidateFailedError): + MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential) + + +@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect) +def test_encrypt_credentials(mock_encrypt): + api_key = 'valid_key' + result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy()) + mock_encrypt.assert_called_with('tenant_id', api_key) + assert result['minimax_api_key'] == f'encrypted_{api_key}' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_custom(mock_decrypt): + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key'] + + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials() + assert result['minimax_api_key'] == 'valid_key' + + +@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect) +def test_get_credentials_obfuscated(mock_decrypt): + encrypted_credential = VALIDATE_CREDENTIAL.copy() + encrypted_credential['minimax_api_key'] = 'encrypted_' + encrypted_credential['minimax_api_key'] + + provider = Provider( + id='provider_id', + tenant_id='tenant_id', + provider_name=PROVIDER_NAME, + provider_type=ProviderType.CUSTOM.value, + encrypted_config=json.dumps(encrypted_credential), + is_valid=True, + ) + model_provider = MODEL_PROVIDER_CLASS(provider=provider) + result = model_provider.get_provider_credentials(obfuscated=True) + middle_token = result['minimax_api_key'][6:-2] + assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['minimax_api_key']) - 8, 0) + assert all(char == '*' for char in middle_token) diff --git a/api/tests/unit_tests/model_providers/test_openai_provider.py b/api/tests/unit_tests/model_providers/test_openai_provider.py new file mode 100644 index 0000000000..3e2f717ee0 --- /dev/null +++ b/api/tests/unit_tests/model_providers/test_openai_provider.py @@ -0,0 +1,126 @@ +import pytest +from 
diff --git a/api/tests/unit_tests/model_providers/test_openai_provider.py b/api/tests/unit_tests/model_providers/test_openai_provider.py
new file mode 100644
index 0000000000..3e2f717ee0
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_openai_provider.py
@@ -0,0 +1,126 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+
+from openai.error import AuthenticationError
+
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.openai_provider import OpenAIProvider
+from models.provider import ProviderType, Provider
+
+PROVIDER_NAME = 'openai'
+MODEL_PROVIDER_CLASS = OpenAIProvider
+VALIDATE_CREDENTIAL_KEY = 'openai_api_key'
+
+
+def moderation_side_effect(*args, **kwargs):
+    if kwargs['api_key'] == 'valid_key':
+        mock_instance = MagicMock()
+        mock_instance.request = MagicMock()
+        return mock_instance, {}
+    else:
+        raise AuthenticationError('Invalid credentials')
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
+def test_is_provider_credentials_valid_or_raise_valid(mock_create):
+    # assert True if api_key is valid
+    credentials = {VALIDATE_CREDENTIAL_KEY: 'valid_key'}
+    assert MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credentials) is None
+
+
+@patch('openai.ChatCompletion.create', side_effect=moderation_side_effect)
+def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
+    # raise CredentialsValidateFailedError if api_key is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    # raise CredentialsValidateFailedError if api_key is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    api_key = 'valid_key'
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
+    mock_encrypt.assert_called_with('tenant_id', api_key)
+    assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom_str(mock_decrypt):
+    """
+    Only the OpenAI provider must stay compatible with the legacy case where encrypted_config was stored as a plain string instead of JSON.
+
+    :param mock_decrypt:
+    :return:
+    """
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config='encrypted_valid_key',
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    openai_api_key = 'valid_key'
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{openai_api_key}'}),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials(obfuscated=True)
+    middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
+    assert len(middle_token) == max(len(openai_api_key) - 8, 0)
+    assert all(char == '*' for char in middle_token)
+
+
+@patch('core.model_providers.providers.hosted.hosted_model_providers.openai')
+def test_get_credentials_hosted(mock_hosted):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.SYSTEM.value,
+        encrypted_config='',
+        is_valid=True
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    mock_hosted.api_key = 'hosted_key'
+    result = model_provider.get_provider_credentials()
+    assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'
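The `test_get_credentials_custom_str` case above pins down the compatibility rule: for OpenAI, `encrypted_config` may be either a JSON object or a bare encrypted token left over from older installs. A sketch of how the provider might branch on that (the helper name is hypothetical; the actual provider code is not part of this diff):

    import json

    def _parse_encrypted_config(encrypted_config: str) -> dict:
        try:
            config = json.loads(encrypted_config)
        except json.JSONDecodeError:
            config = None
        if not isinstance(config, dict):
            # legacy rows stored the bare encrypted token directly
            config = {'openai_api_key': encrypted_config}
        return config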
diff --git a/api/tests/unit_tests/model_providers/test_replicate_provider.py b/api/tests/unit_tests/model_providers/test_replicate_provider.py
new file mode 100644
index 0000000000..e555636f0f
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_replicate_provider.py
@@ -0,0 +1,125 @@
+import pytest
+from unittest.mock import patch, MagicMock
+import json
+
+from core.model_providers.models.entity.model_params import ModelType
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.replicate_provider import ReplicateProvider
+from models.provider import ProviderType, Provider, ProviderModel
+
+PROVIDER_NAME = 'replicate'
+MODEL_PROVIDER_CLASS = ReplicateProvider
+VALIDATE_CREDENTIAL = {
+    'model_version': 'fake-version',
+    'replicate_api_token': 'valid_key'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_credentials_valid_or_raise_valid(mocker):
+    mock_query = MagicMock()
+    mock_query.return_value = None
+    mocker.patch('replicate.model.ModelCollection.get', return_value=mock_query)
+    mocker.patch('replicate.model.Model.versions', return_value=mock_query)
+
+    MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+
+
+def test_is_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if replicate_api_token is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={}
+        )
+
+    # raise CredentialsValidateFailedError if replicate_api_token is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
+            model_name='test_model_name',
+            model_type=ModelType.TEXT_GENERATION,
+            credentials={'replicate_api_token': 'invalid_key'})
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_model_credentials(mock_encrypt):
+    api_key = 'valid_key'
+    result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
+        tenant_id='tenant_id',
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        credentials=VALIDATE_CREDENTIAL.copy()
+    )
+    mock_encrypt.assert_called_with('tenant_id', api_key)
+    assert result['replicate_api_token'] == f'encrypted_{api_key}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_custom(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION
+    )
+    assert result['replicate_api_token'] == 'valid_key'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=None,
+        is_valid=True,
+    )
+
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['replicate_api_token'] = 'encrypted_' + encrypted_credential['replicate_api_token']
+
+    mock_query = MagicMock()
+    mock_query.filter.return_value.first.return_value = ProviderModel(
+        encrypted_config=json.dumps(encrypted_credential)
+    )
+    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
+
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_model_credentials(
+        model_name='test_model_name',
+        model_type=ModelType.TEXT_GENERATION,
+        obfuscated=True
+    )
+    middle_token = result['replicate_api_token'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['replicate_api_token']) - 8, 0)
+    assert all(char == '*' for char in middle_token)
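The two Replicate lookup tests above also fix the contract for model-scoped credentials: `get_model_credentials` fetches a `ProviderModel` row, JSON-decodes its `encrypted_config`, decrypts the secret, and optionally masks it. A condensed sketch of that flow; the filter columns and the reuse of the `obfuscated_token` helper sketched earlier are assumptions, not code from this PR:

    import json

    from core.helper import encrypter
    from extensions.ext_database import db
    from models.provider import ProviderModel

    def get_model_credentials_sketch(provider, model_name, model_type, obfuscated=False):
        provider_model = db.session.query(ProviderModel).filter(
            ProviderModel.tenant_id == provider.tenant_id,          # filter shape assumed
            ProviderModel.provider_name == provider.provider_name,
            ProviderModel.model_name == model_name,
        ).first()
        credentials = json.loads(provider_model.encrypted_config)
        token = encrypter.decrypt_token(provider.tenant_id, credentials['replicate_api_token'])
        credentials['replicate_api_token'] = obfuscated_token(token) if obfuscated else token
        return credentials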
diff --git a/api/tests/unit_tests/model_providers/test_spark_provider.py b/api/tests/unit_tests/model_providers/test_spark_provider.py
new file mode 100644
index 0000000000..7193221f1d
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_spark_provider.py
@@ -0,0 +1,97 @@
+import pytest
+from unittest.mock import patch
+import json
+
+from langchain.schema import AIMessage, ChatResult, ChatGeneration
+
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.spark_provider import SparkProvider
+from models.provider import ProviderType, Provider
+
+
+PROVIDER_NAME = 'spark'
+MODEL_PROVIDER_CLASS = SparkProvider
+VALIDATE_CREDENTIAL = {
+    'app_id': 'valid_app_id',
+    'api_key': 'valid_key',
+    'api_secret': 'valid_secret'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_provider_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('core.third_party.langchain.llms.spark.ChatSpark._generate',
+                 return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content="abc"))]))
+
+    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
+
+
+def test_is_provider_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if api_key is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    credential = VALIDATE_CREDENTIAL.copy()
+    credential['api_key'] = 'invalid_key'
+
+    # raise CredentialsValidateFailedError if api_key is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
+    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
+    assert result['api_secret'] == f'encrypted_{VALIDATE_CREDENTIAL["api_secret"]}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+    encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result['api_key'] == 'valid_key'
+    assert result['api_secret'] == 'valid_secret'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+    encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials(obfuscated=True)
+    middle_token = result['api_key'][6:-2]
+    middle_secret = result['api_secret'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
+    assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['api_secret']) - 8, 0)
+    assert all(char == '*' for char in middle_token)
+    assert all(char == '*' for char in middle_secret)
diff --git a/api/tests/unit_tests/model_providers/test_tongyi_provider.py b/api/tests/unit_tests/model_providers/test_tongyi_provider.py
new file mode 100644
index 0000000000..275a1908fe
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_tongyi_provider.py
@@ -0,0 +1,89 @@
+import pytest
+from unittest.mock import patch
+import json
+
+from langchain.schema import LLMResult, Generation
+
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.tongyi_provider import TongyiProvider
+from models.provider import ProviderType, Provider
+
+
+PROVIDER_NAME = 'tongyi'
+MODEL_PROVIDER_CLASS = TongyiProvider
+VALIDATE_CREDENTIAL = {
+    'dashscope_api_key': 'valid_key'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_provider_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('langchain.llms.tongyi.Tongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
+
+    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
+
+
+def test_is_provider_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if api_key is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    credential = VALIDATE_CREDENTIAL.copy()
+    credential['dashscope_api_key'] = 'invalid_key'
+
+    # raise CredentialsValidateFailedError if api_key is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    api_key = 'valid_key'
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
+    mock_encrypt.assert_called_with('tenant_id', api_key)
+    assert result['dashscope_api_key'] == f'encrypted_{api_key}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result['dashscope_api_key'] == 'valid_key'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials(obfuscated=True)
+    middle_token = result['dashscope_api_key'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['dashscope_api_key']) - 8, 0)
+    assert all(char == '*' for char in middle_token)
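Across the spark and tongyi suites above (and the wenxin suite below), `encrypt_provider_credentials` is asserted to pass each secret field through `encrypter.encrypt_token(tenant_id, value)` and write it back in place. A generic sketch under that assumption; the per-provider field list is illustrative, not taken from this diff:

    from core.helper import encrypter

    # e.g. ('api_key', 'api_secret') for spark, ('dashscope_api_key',) for tongyi
    SECRET_FIELDS = ('api_key', 'api_secret')

    def encrypt_provider_credentials_sketch(tenant_id: str, credentials: dict) -> dict:
        for field in SECRET_FIELDS:
            if field in credentials:
                credentials[field] = encrypter.encrypt_token(tenant_id, credentials[field])
        return credentials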
diff --git a/api/tests/unit_tests/model_providers/test_wenxin_provider.py b/api/tests/unit_tests/model_providers/test_wenxin_provider.py
new file mode 100644
index 0000000000..9f714bb6d3
--- /dev/null
+++ b/api/tests/unit_tests/model_providers/test_wenxin_provider.py
@@ -0,0 +1,93 @@
+import pytest
+from unittest.mock import patch
+import json
+
+from core.model_providers.providers.base import CredentialsValidateFailedError
+from core.model_providers.providers.wenxin_provider import WenxinProvider
+from models.provider import ProviderType, Provider
+
+
+PROVIDER_NAME = 'wenxin'
+MODEL_PROVIDER_CLASS = WenxinProvider
+VALIDATE_CREDENTIAL = {
+    'api_key': 'valid_key',
+    'secret_key': 'valid_secret'
+}
+
+
+def encrypt_side_effect(tenant_id, encrypt_key):
+    return f'encrypted_{encrypt_key}'
+
+
+def decrypt_side_effect(tenant_id, encrypted_key):
+    return encrypted_key.replace('encrypted_', '')
+
+
+def test_is_provider_credentials_valid_or_raise_valid(mocker):
+    mocker.patch('core.third_party.langchain.llms.wenxin.Wenxin._call', return_value="abc")
+
+    MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
+
+
+def test_is_provider_credentials_valid_or_raise_invalid():
+    # raise CredentialsValidateFailedError if api_key is not in credentials
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
+
+    credential = VALIDATE_CREDENTIAL.copy()
+    credential['api_key'] = 'invalid_key'
+
+    # raise CredentialsValidateFailedError if api_key is invalid
+    with pytest.raises(CredentialsValidateFailedError):
+        MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
+
+
+@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
+def test_encrypt_credentials(mock_encrypt):
+    result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
+    assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
+    assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_custom(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials()
+    assert result['api_key'] == 'valid_key'
+    assert result['secret_key'] == 'valid_secret'
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_get_credentials_obfuscated(mock_decrypt):
+    encrypted_credential = VALIDATE_CREDENTIAL.copy()
+    encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
+    encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
+
+    provider = Provider(
+        id='provider_id',
+        tenant_id='tenant_id',
+        provider_name=PROVIDER_NAME,
+        provider_type=ProviderType.CUSTOM.value,
+        encrypted_config=json.dumps(encrypted_credential),
+        is_valid=True,
+    )
+    model_provider = MODEL_PROVIDER_CLASS(provider=provider)
+    result = model_provider.get_provider_credentials(obfuscated=True)
+    middle_token = result['api_key'][6:-2]
+    middle_secret = result['secret_key'][6:-2]
+    assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
+    assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
+    assert all(char == '*' for char in middle_token)
+    assert all(char == '*' for char in middle_secret)