diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000..21ec0d5fa4 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,30 @@ +# Description + +Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. + +Fixes # (issue) + +## Type of Change + +Please delete options that are not relevant. + +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] This change requires a documentation update, included here: [Dify Document](https://github.com/langgenius/dify-docs) + +# How Has This Been Tested? + +Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. Please also list any relevant details for your test configuration. + +- [ ] TODO + +# Suggested Checklist: + +- [ ] I have performed a self-review of my own code +- [ ] I have commented my code, particularly in hard-to-understand areas +- [ ] My changes generate no new warnings +- [ ] I ran `dev/reformat` (backend) and `cd web && npx lint-staged` (frontend) to appease the lint gods +- [ ] `optional` I have made corresponding changes to the documentation +- [ ] `optional` I have added tests that prove my fix is effective or that my feature works +- [ ] `optional` New and existing unit tests pass locally with my changes diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 7b5ed7ddd7..0e1c9d6927 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -10,9 +10,33 @@ concurrency: cancel-in-progress: true jobs: + python-style: + name: Python Style + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Python dependencies + run: pip install ruff + + - name: Ruff check + run: ruff check ./api + + - name: Lint hints + if: failure() + run: echo "Please run 'dev/reformat' to fix the fixable linting errors."
+ test: name: ESLint and SuperLinter runs-on: ubuntu-latest + needs: python-style steps: - name: Checkout code diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml new file mode 100644 index 0000000000..575ead4b3b --- /dev/null +++ b/.github/workflows/tool-test-sdks.yaml @@ -0,0 +1,34 @@ +name: Run Unit Test For SDKs + +on: + pull_request: + branches: + - main +jobs: + build: + name: unit test for Node.js SDK + runs-on: ubuntu-latest + + strategy: + matrix: + node-version: [16, 18, 20] + + defaults: + run: + working-directory: sdks/nodejs-client + + steps: + - uses: actions/checkout@v4 + + - name: Use Node.js ${{ matrix.node-version }} + uses: actions/setup-node@v4 + with: + node-version: ${{ matrix.node-version }} + cache: '' + cache-dependency-path: 'yarn.lock' + + - name: Install Dependencies + run: yarn install + + - name: Test + run: yarn test diff --git a/api/.env.example b/api/.env.example index bf5bf7c4e5..d492c1f8be 100644 --- a/api/.env.example +++ b/api/.env.example @@ -81,11 +81,17 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10 # Model Configuration MULTIMODAL_SEND_IMAGE_FORMAT=base64 -# Mail configuration, support: resend -MAIL_TYPE= +# Mail configuration, support: resend, smtp +MAIL_TYPE=resend MAIL_DEFAULT_SEND_FROM=no-reply RESEND_API_KEY= RESEND_API_URL=https://api.resend.com +# smtp configuration +SMTP_SERVER=smtp.gmail.com +SMTP_PORT=587 +SMTP_USERNAME=123 +SMTP_PASSWORD=abc +SMTP_USE_TLS=false # Sentry configuration SENTRY_DSN= @@ -120,4 +126,7 @@ HOSTED_ANTHROPIC_QUOTA_LIMIT=600000 HOSTED_ANTHROPIC_PAID_ENABLED=false ETL_TYPE=dify -UNSTRUCTURED_API_URL= \ No newline at end of file +UNSTRUCTURED_API_URL= + +SSRF_PROXY_HTTP_URL= +SSRF_PROXY_HTTPS_URL= diff --git a/api/app.py b/api/app.py index cb3b226e4c..bcf3856c13 100644 --- a/api/app.py +++ b/api/app.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import os from werkzeug.exceptions import Unauthorized @@ -19,18 +18,28 @@ import threading import time import warnings -from commands import register_commands -from config import CloudEditionConfig, Config -from events import event_handlers -from extensions import (ext_celery, ext_code_based_extension, ext_database, ext_hosting_provider, ext_login, ext_mail, - ext_migrate, ext_redis, ext_sentry, ext_storage) -from extensions.ext_database import db -from extensions.ext_login import login_manager from flask import Flask, Response, request from flask_cors import CORS + +from commands import register_commands +from config import CloudEditionConfig, Config +from extensions import ( + ext_celery, + ext_code_based_extension, + ext_database, + ext_hosting_provider, + ext_login, + ext_mail, + ext_migrate, + ext_redis, + ext_sentry, + ext_storage, +) +from extensions.ext_database import db +from extensions.ext_login import login_manager from libs.passport import PassportService + # DO NOT REMOVE BELOW -from models import account, dataset, model, source, task, tool, tools, web from services.account_service import AccountService # DO NOT REMOVE ABOVE diff --git a/api/commands.py b/api/commands.py index b44f166926..91b50445e6 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,11 +3,13 @@ import json import secrets import click +from flask import current_app +from werkzeug.exceptions import NotFound + from core.embedding.cached_embedding import CacheEmbedding from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db -from flask import current_app from libs.helper import 
email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair @@ -15,7 +17,6 @@ from models.account import Tenant from models.dataset import Dataset from models.model import Account from models.provider import Provider, ProviderModel -from werkzeug.exceptions import NotFound @click.command('reset-password', help='Reset the account password.') diff --git a/api/config.py b/api/config.py index 66e4c90a6f..1728a18d0b 100644 --- a/api/config.py +++ b/api/config.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import os import dotenv @@ -87,7 +86,7 @@ class Config: # ------------------------ # General Configurations. # ------------------------ - self.CURRENT_VERSION = "0.5.3" + self.CURRENT_VERSION = "0.5.5" self.COMMIT_SHA = get_env('COMMIT_SHA') self.EDITION = "SELF_HOSTED" self.DEPLOY_ENV = get_env('DEPLOY_ENV') @@ -209,6 +208,12 @@ class Config: self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM') self.RESEND_API_KEY = get_env('RESEND_API_KEY') self.RESEND_API_URL = get_env('RESEND_API_URL') + # SMTP settings + self.SMTP_SERVER = get_env('SMTP_SERVER') + self.SMTP_PORT = get_env('SMTP_PORT') + self.SMTP_USERNAME = get_env('SMTP_USERNAME') + self.SMTP_PASSWORD = get_env('SMTP_PASSWORD') + self.SMTP_USE_TLS = get_bool_env('SMTP_USE_TLS') # ------------------------ # Workpace Configurations. diff --git a/api/constants/languages.py b/api/constants/languages.py index 6a6a3cf311..284f3d8758 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -1,9 +1,8 @@ - import json from models.model import AppModelConfig -languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT'] +languages = ['en-US', 'zh-Hans', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP', 'ko-KR', 'ru-RU', 'it-IT', 'uk-UA'] language_timezone_mapping = { 'en-US': 'America/New_York', @@ -16,8 +15,10 @@ language_timezone_mapping = { 'ko-KR': 'Asia/Seoul', 'ru-RU': 'Europe/Moscow', 'it-IT': 'Europe/Rome', + 'uk-UA': 'Europe/Kyiv', } + def supported_language(lang): if lang in languages: return lang error = ('{lang} is not a valid language.' .format(lang=lang)) raise ValueError(error) + user_input_form_template = { "en-US": [ { "paragraph": { "label": "Query", "variable": "default_input", "required": False, "default": "" } } ], @@ -67,6 +69,16 @@ user_input_form_template = { } } ], + "uk-UA": [ + { + "paragraph": { + "label": "Запит", + "variable": "default_input", + "required": False, + "default": "" + } + } + ], } demo_model_templates = { @@ -145,7 +157,7 @@ demo_model_templates = { 'Italian', ] } - },{ + }, { "paragraph": { "label": "Query", "variable": "query", @@ -272,7 +284,7 @@ demo_model_templates = { "意大利语", ] } - },{ + }, { "paragraph": { "label": "文本内容", "variable": "query", @@ -323,5 +335,130 @@ demo_model_templates = { ) } ], + 'uk-UA': [{ + "name": "Помічник перекладу", + "icon": "", + "icon_background": "", + "description": "Багатомовний перекладач, який надає можливості перекладу різними мовами, перекладаючи введені користувачем дані на потрібну мову.", + "mode": "completion", + "model_config": AppModelConfig( + provider="openai", + model_id="gpt-3.5-turbo-instruct", + configs={ + "prompt_template": "Будь ласка, перекладіть наступний текст на {{target_language}}:\n", + "prompt_variables": [ + { + "key": "target_language", + "name": "Цільова мова", + "description": "Мова, на яку ви хочете перекласти.", + "type": "select", + "default": "Ukrainian", + "options": [ + "Chinese", + "English", + "Japanese", + "French", + "Russian", + "German", + "Spanish", + "Korean", + "Italian", + ], +
}, + ], + "completion_params": { + "max_token": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1, + }, + }, + opening_statement="", + suggested_questions=None, + pre_prompt="Будь ласка, перекладіть наступний текст на {{target_language}}:\n{{query}}\ntranslate:", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo-instruct", + "mode": "completion", + "completion_params": { + "max_tokens": 1000, + "temperature": 0, + "top_p": 0, + "presence_penalty": 0.1, + "frequency_penalty": 0.1, + }, + }), + user_input_form=json.dumps([ + { + "select": { + "label": "Цільова мова", + "variable": "target_language", + "description": "Мова, на яку ви хочете перекласти.", + "default": "Chinese", + "required": True, + 'options': [ + 'Chinese', + 'English', + 'Japanese', + 'French', + 'Russian', + 'German', + 'Spanish', + 'Korean', + 'Italian', + ] + } + }, { + "paragraph": { + "label": "Запит", + "variable": "query", + "required": True, + "default": "" + } + } + ]) + ) + }, + { + "name": "AI інтерв’юер фронтенду", + "icon": "", + "icon_background": "", + "description": "Симульований інтерв’юер фронтенду, який перевіряє рівень кваліфікації у розробці фронтенду через опитування.", + "mode": "chat", + "model_config": AppModelConfig( + provider="openai", + model_id="gpt-3.5-turbo", + configs={ + "introduction": "Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ", + "prompt_template": "Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", + "prompt_variables": [], + "completion_params": { + "max_token": 300, + "temperature": 0.8, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.1, + }, + }, + opening_statement="Привіт, ласкаво просимо на наше співбесіду. Я інтерв'юер цієї технологічної компанії, і я перевірю ваші навички веб-розробки фронтенду. Далі я поставлю вам декілька технічних запитань. Будь ласка, відповідайте якомога ретельніше. ", + suggested_questions=None, + pre_prompt="Ви будете грати роль інтерв'юера технологічної компанії, перевіряючи навички розробки фронтенду користувача та ставлячи 5-10 чітких технічних питань.\n\nЗверніть увагу:\n- Ставте лише одне запитання за раз.\n- Після того, як користувач відповість на запитання, ставте наступне запитання безпосередньо, не намагаючись виправити будь-які помилки, допущені кандидатом.\n- Якщо ви вважаєте, що користувач не відповів правильно на кілька питань поспіль, задайте менше запитань.\n- Після того, як ви задали останнє запитання, ви можете поставити таке запитання: Чому ви залишили свою попередню роботу? 
Після того, як користувач відповість на це питання, висловіть своє розуміння та підтримку.\n", + model=json.dumps({ + "provider": "openai", + "name": "gpt-3.5-turbo", + "mode": "chat", + "completion_params": { + "max_tokens": 300, + "temperature": 0.8, + "top_p": 0.9, + "presence_penalty": 0.1, + "frequency_penalty": 0.1, + }, + }), + user_input_form=None + ), + } + ], } diff --git a/api/constants/model_template.py b/api/constants/model_template.py index 5ec0f3125e..5b9a09fd9b 100644 --- a/api/constants/model_template.py +++ b/api/constants/model_template.py @@ -1,7 +1,5 @@ import json -from models.model import App, AppModelConfig - model_templates = { # completion default mode 'completion_default': { diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 4b8d27434d..aaa737f83a 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,14 +1,15 @@ import os from functools import wraps +from flask import request +from flask_restful import Resource, reqparse +from werkzeug.exceptions import NotFound, Unauthorized + from constants.languages import supported_language from controllers.console import api from controllers.console.wraps import only_edition_cloud from extensions.ext_database import db -from flask import request -from flask_restful import Resource, reqparse from models.model import App, InstalledApp, RecommendedApp -from werkzeug.exceptions import NotFound, Unauthorized def admin_required(view): diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b8dd1ed5bf..324b831175 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,12 +1,13 @@ import flask_restful -from extensions.ext_database import db from flask_login import current_user from flask_restful import Resource, fields, marshal_with +from werkzeug.exceptions import Forbidden + +from extensions.ext_database import db from libs.helper import TimestampField from libs.login import login_required from models.dataset import Dataset from models.model import ApiToken, App -from werkzeug.exceptions import Forbidden from . 
import api from .setup import setup_required diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index c7693fb950..fa2b3807e8 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -1,7 +1,8 @@ +from flask_restful import Resource, reqparse + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from flask_restful import Resource, reqparse from libs.login import login_required from services.advanced_prompt_template_service import AdvancedPromptTemplateService diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 8c7cae9519..1ac8e60dcd 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -1,17 +1,20 @@ +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, marshal_with, reqparse +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.app.error import NoFileUploadedError from controllers.console.datasets.error import TooManyFilesError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_redis import redis_client -from fields.annotation_fields import (annotation_fields, annotation_hit_history_fields, - annotation_hit_history_list_fields, annotation_list_fields) -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse +from fields.annotation_fields import ( + annotation_fields, + annotation_hit_history_fields, +) from libs.login import login_required from services.annotation_service import AppAnnotationService -from werkzeug.exceptions import Forbidden class AnnotationReplyActionApi(Resource): diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 2aac27af3e..87cad07462 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,8 +1,11 @@ -# -*- coding:utf-8 -*- import json import logging from datetime import datetime +from flask_login import current_user +from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from werkzeug.exceptions import Forbidden + from constants.languages import demo_model_templates, languages from constants.model_template import model_templates from controllers.console import api @@ -15,16 +18,15 @@ from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from events.app_event import app_was_created, app_was_deleted from extensions.ext_database import db -from fields.app_fields import (app_detail_fields, app_detail_fields_with_site, app_pagination_fields, - template_list_fields) -from flask import current_app -from flask_login import current_user -from flask_restful import Resource, abort, inputs, marshal_with, reqparse +from fields.app_fields import ( + app_detail_fields, + app_detail_fields_with_site, + app_pagination_fields, + template_list_fields, +) from libs.login import login_required from models.model import App, AppModelConfig, Site -from models.tools import ApiToolProvider from services.app_model_config_service import AppModelConfigService -from 
werkzeug.exceptions import Forbidden def _get_app(app_id, tenant_id): @@ -130,8 +132,8 @@ class AppListApi(Resource): if not model_instance: raise ProviderNotInitializeError( - f"No Default System Reasoning Model available. Please configure " - f"in the Settings -> Model Provider.") + "No Default System Reasoning Model available. Please configure " + "in the Settings -> Model Provider.") else: model_config_dict["model"]["provider"] = model_instance.provider model_config_dict["model"]["name"] = model_instance.model diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index d7a4f3e3e0..ac90dfcc8d 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -1,24 +1,35 @@ -# -*- coding:utf-8 -*- import logging +from flask import request +from flask_restful import Resource, reqparse +from werkzeug.exceptions import InternalServerError + import services from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request -from flask_restful import Resource from libs.login import login_required from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class ChatMessageAudioApi(Resource): @@ -34,7 +45,9 @@ class ChatMessageAudioApi(Resource): try: response = AudioService.transcript_asr( tenant_id=app_model.tenant_id, - file=file + file=file, + end_user=None, + promot=app_model.app_model_config.pre_prompt ) return response @@ -60,7 +73,7 @@ class ChatMessageAudioApi(Resource): except ValueError as e: raise e except Exception as e: - logging.exception("internal server error.") + logging.exception(f"internal server error, {str(e)}.") raise InternalServerError() @@ -71,10 +84,12 @@ class ChatMessageTextApi(Resource): def post(self, app_id): app_id = str(app_id) app_model = _get_app(app_id, None) + try: response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, text=request.form['text'], + voice=app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=False ) @@ -101,9 +116,50 @@ class ChatMessageTextApi(Resource): except ValueError as e: raise e except Exception as e: - logging.exception("internal server error.") + logging.exception(f"internal server error, {str(e)}.") + raise InternalServerError() + 
+ +class TextModesApi(Resource): + def get(self, app_id: str): + app_model = _get_app(str(app_id)) + + try: + parser = reqparse.RequestParser() + parser.add_argument('language', type=str, required=True, location='args') + args = parser.parse_args() + + response = AudioService.transcript_tts_voices( + tenant_id=app_model.tenant_id, + language=args['language'], + ) + + return response + except services.errors.audio.ProviderNotSupportTextToSpeechLanageServiceError: + raise AppUnavailableError("Text to audio voices language parameter loss.") + except NoAudioUploadedServiceError: + raise NoAudioUploadedError() + except AudioTooLargeServiceError as e: + raise AudioTooLargeError(str(e)) + except UnsupportedAudioTypeServiceError: + raise UnsupportedAudioTypeError() + except ProviderNotSupportSpeechToTextServiceError: + raise ProviderNotSupportSpeechToTextError() + except ProviderTokenNotInitError as ex: + raise ProviderNotInitializeError(ex.description) + except QuotaExceededError: + raise ProviderQuotaExceededError() + except ModelCurrentlyNotSupportError: + raise ProviderModelCurrentlyNotSupportError() + except InvokeError as e: + raise CompletionRequestError(e.description) + except ValueError as e: + raise e + except Exception as e: + logging.exception(f"internal server error, {str(e)}.") raise InternalServerError() api.add_resource(ChatMessageAudioApi, '/apps/<uuid:app_id>/audio-to-text') api.add_resource(ChatMessageTextApi, '/apps/<uuid:app_id>/text-to-audio') +api.add_resource(TextModesApi, '/apps/<uuid:app_id>/text-to-audio/voices') diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index b530a9ee2f..f01d2afa03 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -1,27 +1,33 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union import flask_login +from flask import Response, stream_with_context +from flask_restful import Resource, reqparse +from werkzeug.exceptions import InternalServerError, NotFound + import services from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.console.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_restful import Resource, reqparse from libs.helper import uuid_value from libs.login import login_required from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound # define completion message api for user @@ -163,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200,
mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index f159f74c71..452b0fddf6 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,22 +1,27 @@ from datetime import datetime import pytz +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from flask_restful.inputs import int_range +from sqlalchemy import func, or_ +from sqlalchemy.orm import joinedload +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from fields.conversation_fields import (conversation_detail_fields, conversation_message_detail_fields, - conversation_pagination_fields, conversation_with_summary_pagination_fields) -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse -from flask_restful.inputs import int_range +from fields.conversation_fields import ( + conversation_detail_fields, + conversation_message_detail_fields, + conversation_pagination_fields, + conversation_with_summary_pagination_fields, +) from libs.helper import datetime_string from libs.login import login_required from models.model import Conversation, Message, MessageAnnotation -from sqlalchemy import func, or_ -from sqlalchemy.orm import joinedload -from werkzeug.exceptions import NotFound class CompletionConversationApi(Resource): diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index d7a320db99..3ec932b5f1 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,13 +1,18 @@ +from flask_login import current_user +from flask_restful import Resource, reqparse + from controllers.console import api -from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.generator.llm_generator import LLMGenerator from core.model_runtime.errors.invoke import InvokeError -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 50b4e2d983..0064dbe663 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -1,12 +1,23 @@ import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union + +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with, reqparse +from 
flask_restful.inputs import int_range +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound from controllers.console import api from controllers.console.app import _get_app -from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from core.entities.application_entities import InvokeFrom @@ -14,10 +25,6 @@ from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotIni from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.conversation_fields import annotation_fields, message_detail_fields -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required @@ -28,7 +35,6 @@ from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError from services.message_service import MessageService -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound class ChatMessageListApi(Resource): @@ -241,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index d447bfa756..f67fff4b06 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -1,4 +1,7 @@ -# -*- coding:utf-8 -*- + +from flask import request +from flask_login import current_user +from flask_restful import Resource from controllers.console import api from controllers.console.app import _get_app @@ -6,9 +9,6 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from events.app_event import app_model_config_was_updated from extensions.ext_database import db -from flask import request -from flask_login import current_user -from flask_restful import Resource from libs.login import login_required from models.model import AppModelConfig from services.app_model_config_service import AppModelConfigService diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index daba012bd9..4e9d9ed9b4 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,7 @@ -# -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, NotFound + from constants.languages import 
supported_language from controllers.console import api from controllers.console.app import _get_app @@ -6,11 +9,8 @@ from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from fields.app_fields import app_site_fields -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.model import Site -from werkzeug.exceptions import Forbidden, NotFound def parse_app_site_args(): diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index f2c1726433..7aed7da404 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -1,16 +1,16 @@ -# -*- coding:utf-8 -*- from datetime import datetime from decimal import Decimal import pytz +from flask import jsonify +from flask_login import current_user +from flask_restful import Resource, reqparse + from controllers.console import api from controllers.console.app import _get_app from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from flask import jsonify -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.helper import datetime_string from libs.login import login_required diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 0b3672efc9..20e028af99 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -2,14 +2,15 @@ import base64 import secrets from datetime import datetime +from flask_restful import Resource, reqparse + from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db -from flask_restful import Resource, reqparse from libs.helper import email, str_len, timezone from libs.password import hash_password, valid_password -from models.account import AccountStatus, Tenant +from models.account import AccountStatus from services.account_service import RegisterService diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index d0b28c6d4b..293ec1c4d3 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -1,13 +1,14 @@ import logging import requests -from controllers.console import api from flask import current_app, redirect, request from flask_login import current_user from flask_restful import Resource +from werkzeug.exceptions import Forbidden + +from controllers.console import api from libs.login import login_required from libs.oauth_data_source import NotionOAuth -from werkzeug.exceptions import Forbidden from ..setup import setup_required from ..wraps import account_initialization_required diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 2c8fdeeaf5..d8cea95f48 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -1,14 +1,13 @@ -# -*- coding:utf-8 -*- -import flask import flask_login +from flask import current_app, request +from flask_restful import Resource, reqparse + import services from controllers.console import api from controllers.console.setup import setup_required -from flask 
import current_app, request -from flask_restful import Resource, reqparse from libs.helper import email from libs.password import valid_password -from services.account_service import AccountService +from services.account_service import AccountService, TenantService class LoginApi(Resource): @@ -30,6 +29,8 @@ class LoginApi(Resource): except services.errors.account.AccountLoginError: return {'code': 'unauthorized', 'message': 'Invalid email or password'}, 401 + TenantService.create_owner_tenant_if_not_exist(account) + AccountService.update_last_login(account, request) # todo: return the user info diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index b7d4e51910..05b1c36873 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -3,13 +3,14 @@ from datetime import datetime from typing import Optional import requests -from constants.languages import languages -from extensions.ext_database import db from flask import current_app, redirect, request from flask_restful import Resource + +from constants.languages import languages +from extensions.ext_database import db from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models.account import Account, AccountStatus -from services.account_service import AccountService, RegisterService +from services.account_service import AccountService, RegisterService, TenantService from .. import api @@ -75,6 +76,8 @@ class OAuthCallback(Resource): account.initialized_at = datetime.utcnow() db.session.commit() + TenantService.create_owner_tenant_if_not_exist(account) + AccountService.update_last_login(account, request) token = AccountService.get_account_jwt_token(account) diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 71de01c779..72a6129efa 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -1,8 +1,9 @@ +from flask_login import current_user +from flask_restful import Resource, reqparse + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, only_edition_cloud -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.billing_service import BillingService diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index a9ecd3d27d..86fcf704c7 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,6 +1,11 @@ import datetime import json +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required @@ -8,15 +13,11 @@ from core.data_loader.loader.notion import NotionLoader from core.indexing_runner import IndexingRunner from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.dataset import Document from models.source import DataSourceBinding from 
services.dataset_service import DatasetService, DocumentService from tasks.document_indexing_sync_task import document_indexing_sync_task -from werkzeug.exceptions import NotFound class DataSourceApi(Resource): diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 01700ea63b..2d26d0ecf4 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -1,5 +1,9 @@ -# -*- coding:utf-8 -*- import flask_restful +from flask import current_app, request +from flask_login import current_user +from flask_restful import Resource, marshal, marshal_with, reqparse +from werkzeug.exceptions import Forbidden, NotFound + import services from controllers.console import api from controllers.console.apikey import api_key_fields, api_key_list @@ -15,14 +19,10 @@ from extensions.ext_database import db from fields.app_fields import related_app_list from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields from fields.document_fields import document_status_fields -from flask import current_app, request -from flask_login import current_user -from flask_restful import Resource, marshal, marshal_with, reqparse from libs.login import login_required from models.dataset import Dataset, Document, DocumentSegment from models.model import ApiToken, UploadFile from services.dataset_service import DatasetService, DocumentService -from werkzeug.exceptions import Forbidden, NotFound def _validate_name(name): @@ -287,8 +287,8 @@ class DatasetIndexingEstimateApi(Resource): args['indexing_technique']) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) elif args['info_list']['data_source_type'] == 'notion_import': @@ -303,8 +303,8 @@ class DatasetIndexingEstimateApi(Resource): args['indexing_technique']) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. 
Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) else: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 586bbafbb0..3fb6f16cd6 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,36 +1,51 @@ -# -*- coding:utf-8 -*- from datetime import datetime -from typing import List + +from flask import request +from flask_login import current_user +from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from sqlalchemy import asc, desc +from werkzeug.exceptions import Forbidden, NotFound import services from controllers.console import api -from controllers.console.app.error import (ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) -from controllers.console.datasets.error import (ArchivedDocumentImmutableError, DocumentAlreadyFinishedError, - DocumentIndexingError, InvalidActionError, InvalidMetadataError) +from controllers.console.app.error import ( + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import ( + ArchivedDocumentImmutableError, + DocumentAlreadyFinishedError, + DocumentIndexingError, + InvalidActionError, + InvalidMetadataError, +) from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check -from core.errors.error import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, - QuotaExceededError) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from extensions.ext_redis import redis_client -from fields.document_fields import (dataset_and_document_fields, document_fields, document_status_fields, - document_with_segments_fields) -from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from fields.document_fields import ( + dataset_and_document_fields, + document_fields, + document_status_fields, + document_with_segments_fields, +) from libs.login import login_required from models.dataset import Dataset, DatasetProcessRule, Document, DocumentSegment from models.model import UploadFile from services.dataset_service import DatasetService, DocumentService -from sqlalchemy import asc, desc from tasks.add_document_to_index_task import add_document_to_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task -from werkzeug.exceptions import Forbidden, NotFound class DocumentResource(Resource): @@ -54,7 +69,7 @@ class DocumentResource(Resource): return document - def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]: + def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise NotFound('Dataset not found.') @@ -279,8 +294,8 @@ class DatasetInitApi(Resource): ) 
except InvokeAuthorizationError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -355,8 +370,8 @@ class DocumentIndexingEstimateApi(DocumentResource): 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -425,8 +440,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) elif dataset.data_source_type == 'notion_import': @@ -439,8 +454,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): None, 'English', dataset_id) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) else: diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 8de5bc91d7..1395963f1d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,8 +1,12 @@ -# -*- coding:utf-8 -*- import uuid from datetime import datetime import pandas as pd +from flask import request +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from werkzeug.exceptions import Forbidden, NotFound + import services from controllers.console import api from controllers.console.app.error import ProviderNotInitializeError @@ -15,16 +19,12 @@ from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import segment_fields -from flask import request -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse from libs.login import login_required from models.dataset import DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.disable_segment_from_index_task import disable_segment_from_index_task from tasks.enable_segment_to_index_task import enable_segment_to_index_task -from werkzeug.exceptions import Forbidden, NotFound class DatasetDocumentSegmentListApi(Resource): @@ -142,8 +142,8 @@ class DatasetDocumentSegmentApi(Resource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. 
Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -233,8 +233,8 @@ class DatasetDocumentSegmentAddApi(Resource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) try: @@ -285,8 +285,8 @@ class DatasetDocumentSegmentUpdateApi(Resource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index c15b3c0cd4..0eba232289 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -1,13 +1,18 @@ -import services -from controllers.console import api -from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, - UnsupportedFileTypeError) -from controllers.console.setup import setup_required -from controllers.console.wraps import account_initialization_required -from fields.file_fields import file_fields, upload_config_fields from flask import current_app, request from flask_login import current_user from flask_restful import Resource, marshal_with + +import services +from controllers.console import api +from controllers.console.datasets.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.setup import setup_required +from controllers.console.wraps import account_initialization_required +from fields.file_fields import file_fields, upload_config_fields from libs.login import login_required from services.file_service import ALLOWED_EXTENSIONS, UNSTRUSTURED_ALLOWED_EXTENSIONS, FileService diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index a32a3217e5..faadc9a145 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -1,22 +1,31 @@ import logging +from flask_login import current_user +from flask_restful import Resource, marshal, reqparse +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + import services from controllers.console import api -from controllers.console.app.error import (CompletionRequestError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.datasets.error import DatasetNotInitializedError, HighQualityDatasetOnlyError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from core.errors.error 
import (LLMBadRequestError, ModelCurrentlyNotSupportError, ProviderTokenNotInitError, - QuotaExceededError) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) from core.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields -from flask_login import current_user -from flask_restful import Resource, marshal, reqparse from libs.login import login_required from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound class HitTestingApi(Resource): @@ -67,8 +76,8 @@ class HitTestingApi(Resource): raise ProviderModelCurrentlyNotSupportError() except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model or Reranking Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model or Reranking Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 784c0c6330..f957d38174 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,21 +1,32 @@ -# -*- coding:utf-8 -*- import logging +from flask import request +from werkzeug.exceptions import InternalServerError + import services from controllers.console import api -from controllers.console.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request from models.model import AppModelConfig from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class ChatAudioApi(InstalledAppResource): @@ -74,6 +85,7 @@ class ChatTextApi(InstalledAppResource): response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, text=request.form['text'], + voice=app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=False ) return {'data': response.data.decode('latin1')} diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index b608130307..6406d5b3b0 100644 --- a/api/controllers/console/explore/completion.py +++ 
b/api/controllers/console/explore/completion.py @@ -1,14 +1,24 @@ -# -*- coding:utf-8 -*- import json import logging +from collections.abc import Generator from datetime import datetime -from typing import Generator, Union +from typing import Union + +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.console.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.console.explore.error import NotChatAppError, NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from core.application_queue_manager import ApplicationQueueManager @@ -16,12 +26,8 @@ from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound # define completion api for user @@ -158,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index 1b6b493671..34a5904eca 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -1,16 +1,16 @@ -# -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.explore.error import NotChatAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService -from werkzeug.exceptions import NotFound class ConversationListApi(InstalledAppResource): diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index e3180bf987..89c4d113a3 100644 --- 
a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 44c54427a4..920d9141ae 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,18 +1,18 @@ -# -*- coding:utf-8 -*- from datetime import datetime +from flask_login import current_user +from flask_restful import Resource, inputs, marshal_with, reqparse +from sqlalchemy import and_ +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields -from flask_login import current_user -from flask_restful import Resource, inputs, marshal_with, reqparse from libs.login import login_required from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from sqlalchemy import and_ -from werkzeug.exceptions import BadRequest, Forbidden, NotFound class InstalledAppsListApi(Resource): diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 684ecd8b28..47af28425f 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -1,31 +1,39 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union + +from flask import Response, stream_with_context +from flask_login import current_user +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.console import api -from controllers.console.app.error import (AppMoreLikeThisDisabledError, CompletionRequestError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) -from controllers.console.explore.error import (AppSuggestedQuestionsAfterAnswerDisabledError, NotChatAppError, - NotCompletionAppError) +from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.explore.error import ( + AppSuggestedQuestionsAfterAnswerDisabledError, + NotChatAppError, + NotCompletionAppError, +) from controllers.console.explore.wraps import InstalledAppResource from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.message_fields import message_infinite_scroll_pagination_fields -from flask import Response, stream_with_context -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.completion_service import CompletionService from services.errors.app import 
MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService -from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(InstalledAppResource): @@ -115,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index c073ebad01..c4afb0b923 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,11 +1,11 @@ -# -*- coding:utf-8 -*- import json +from flask import current_app +from flask_restful import fields, marshal_with + from controllers.console import api from controllers.console.explore.wraps import InstalledAppResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import AppModelConfig, InstalledApp from models.tools import ApiToolProvider @@ -77,7 +77,7 @@ class ExploreAppMetaApi(InstalledAppResource): # get all tools tools = agent_config.get('tools', []) url_prefix = (current_app.config.get("CONSOLE_API_URL") - + f"/console/api/workspaces/current/tool-provider/builtin/") + + "/console/api/workspaces/current/tool-provider/builtin/") for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 4ce8fbfbe9..fd90be03b1 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -1,15 +1,15 @@ -# -*- coding:utf-8 -*- +from flask_login import current_user +from flask_restful import Resource, fields, marshal_with +from sqlalchemy import and_ + from constants.languages import languages from controllers.console import api from controllers.console.app.error import AppNotFoundError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from flask_login import current_user -from flask_restful import Resource, fields, marshal_with from libs.login import login_required from models.model import App, InstalledApp, RecommendedApp from services.account_service import TenantService -from sqlalchemy import and_ app_fields = { 'id': fields.String, diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 9d355df355..cf86b2fee1 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -1,14 +1,15 @@ +from flask_login import current_user +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.console import api from controllers.console.explore.error import NotCompletionAppError from controllers.console.explore.wraps import InstalledAppResource from fields.conversation_fields import message_file_fields -from flask_login import current_user -from flask_restful import fields, marshal_with, reqparse 
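Every `compact_response` helper touched in this diff follows the same Flask pattern: a blocking call returns one JSON document, while a streaming call wraps a generator in `stream_with_context` and serves it as server-sent events. Reduced to its essentials (payloads here are made up for illustration):

```python
import json
from collections.abc import Generator
from typing import Union

from flask import Response, stream_with_context


def compact_response(response: Union[dict, Generator]) -> Response:
    if isinstance(response, dict):
        # blocking mode: a single JSON body
        return Response(response=json.dumps(response), status=200,
                        mimetype='application/json')

    def generate() -> Generator:
        # streaming mode: keep the request context alive while chunks flow
        yield from response

    return Response(stream_with_context(generate()), status=200,
                    mimetype='text/event-stream')


assert compact_response({'event': 'message'}).mimetype == 'application/json'
```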
-from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -from werkzeug.exceptions import NotFound feedback_fields = { 'rating': fields.String diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index d02b869bf7..84890f1b46 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -1,12 +1,13 @@ from functools import wraps +from flask_login import current_user +from flask_restful import Resource +from werkzeug.exceptions import NotFound + from controllers.console.wraps import account_initialization_required from extensions.ext_database import db -from flask_login import current_user -from flask_restful import Resource from libs.login import login_required from models.model import InstalledApp -from werkzeug.exceptions import NotFound def installed_app_required(view=None): diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index 78374cf2a9..fa73c44c22 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -1,9 +1,10 @@ +from flask_login import current_user +from flask_restful import Resource, marshal_with, reqparse + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from fields.api_based_extension_fields import api_based_extension_fields -from flask_login import current_user -from flask_restful import Resource, marshal_with, reqparse from libs.login import login_required from models.api_based_extension import APIBasedExtension from services.api_based_extension_service import APIBasedExtensionService diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 40f86fc235..824549050f 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -1,5 +1,6 @@ from flask_login import current_user from flask_restful import Resource + from services.feature_service import FeatureService from . 
import api diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index d1994a84c9..b319f706b4 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -2,6 +2,7 @@ import os from flask import current_app, session from flask_restful import Resource, reqparse + from libs.helper import str_len from models.model import DifySetup from services.account_service import TenantService diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index 765161ff9d..a8d0dd4344 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -1,9 +1,9 @@ -# -*- coding:utf-8 -*- from functools import wraps -from extensions.ext_database import db from flask import current_app, request from flask_restful import Resource, reqparse + +from extensions.ext_database import db from libs.helper import email, str_len from libs.password import valid_password from models.model import DifySetup diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index ba49506618..a50e4c41a8 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import json import logging @@ -6,7 +5,6 @@ import logging import requests from flask import current_app from flask_restful import Resource, reqparse -from werkzeug.exceptions import InternalServerError from . import api diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1f856394e2..b7cfba9d04 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,17 +1,21 @@ -# -*- coding:utf-8 -*- from datetime import datetime import pytz -from constants.languages import supported_language -from controllers.console import api -from controllers.console.setup import setup_required -from controllers.console.workspace.error import (AccountAlreadyInitedError, CurrentPasswordIncorrectError, - InvalidInvitationCodeError, RepeatPasswordNotMatchError) -from controllers.console.wraps import account_initialization_required -from extensions.ext_database import db from flask import current_app, request from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse + +from constants.languages import supported_language +from controllers.console import api +from controllers.console.setup import setup_required +from controllers.console.workspace.error import ( + AccountAlreadyInitedError, + CurrentPasswordIncorrectError, + InvalidInvitationCodeError, + RepeatPasswordNotMatchError, +) +from controllers.console.wraps import account_initialization_required +from extensions.ext_database import db from libs.helper import TimestampField, timezone from libs.login import login_required from models.account import AccountIntegrate, InvitationCode diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 84be878545..cf57cd4b24 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,16 +1,17 @@ -# -*- coding:utf-8 -*- +from flask import current_app +from flask_login import current_user +from flask_restful import Resource, abort, fields, marshal_with, reqparse + import services from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, 
cloud_edition_billing_resource_check from extensions.ext_database import db -from flask import current_app -from flask_login import current_user -from flask_restful import Resource, abort, fields, marshal_with, reqparse from libs.helper import TimestampField from libs.login import login_required from models.account import Account from services.account_service import RegisterService, TenantService +from services.errors.account import AccountAlreadyInTenantError account_fields = { 'id': fields.String, @@ -71,6 +72,13 @@ class MemberInviteEmailApi(Resource): 'email': invitee_email, 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' }) + except AccountAlreadyInTenantError: + invitation_results.append({ + 'status': 'success', + 'email': invitee_email, + 'url': f'{console_web_url}/signin' + }) + break except Exception as e: invitation_results.append({ 'status': 'failed', diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index cb76e5cdd2..c888159f83 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -1,18 +1,19 @@ import io +from flask import send_file +from flask_login import current_user +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder -from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.billing_service import BillingService from services.model_provider_service import ModelProviderService -from werkzeug.exceptions import Forbidden class ModelProviderListApi(Resource): diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 305c9f09af..5745c0d408 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -1,16 +1,17 @@ import logging +from flask_login import current_user +from flask_restful import Resource, reqparse +from werkzeug.exceptions import Forbidden + from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.utils.encoders import jsonable_encoder -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.model_provider_service import ModelProviderService -from werkzeug.exceptions import Forbidden class DefaultModelApi(Resource): diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index fb42146eee..c2c5286d51 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,15 +1,15 @@ import io -import json + +from flask import send_file +from flask_login import current_user +from flask_restful import Resource, reqparse +from 
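The `MemberInviteEmailApi` hunk above also changes invite semantics: an invitee who already belongs to the tenant is now reported as a success with a plain sign-in URL rather than surfacing an error, and the `break` stops processing any remaining addresses. A reduced sketch of the new control flow, with `register` standing in for the real invite service call:

```python
class AccountAlreadyInTenantError(Exception):
    pass


def invite_members(emails, register, console_web_url):
    # ``register`` stands in for the invite service and returns an
    # activation token; names mirror the diff, heavily simplified.
    results = []
    for email in emails:
        try:
            token = register(email)
            results.append({'status': 'success', 'email': email,
                            'url': f'{console_web_url}/activate?email={email}&token={token}'})
        except AccountAlreadyInTenantError:
            # existing member: report success with a sign-in link, then
            # stop processing further addresses (note the break)
            results.append({'status': 'success', 'email': email,
                            'url': f'{console_web_url}/signin'})
            break
        except Exception:
            results.append({'status': 'failed', 'email': email})
    return results
```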
werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required -from flask import send_file -from flask_login import current_user -from flask_restful import Resource, reqparse from libs.login import login_required from services.tools_manage_service import ToolManageService -from werkzeug.exceptions import Forbidden class ToolProviderListApi(Resource): diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 8f00d76f7a..7b3f08f467 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -1,18 +1,22 @@ -# -*- coding:utf-8 -*- import logging +from flask import request +from flask_login import current_user +from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse + import services from controllers.console import api from controllers.console.admin import admin_required -from controllers.console.datasets.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, - UnsupportedFileTypeError) +from controllers.console.datasets.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.console.error import AccountNotLinkTenantError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db -from flask import request -from flask_login import current_user -from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse from libs.helper import TimestampField from libs.login import login_required from models.account import Tenant diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 7bfb064a23..d5777a330c 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -1,10 +1,10 @@ -# -*- coding:utf-8 -*- import json from functools import wraps -from controllers.console.workspace.error import AccountNotInitializedError from flask import abort, current_app, request from flask_login import current_user + +from controllers.console.workspace.error import AccountNotInitializedError from services.feature_service import FeatureService from services.operation_service import OperationService diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 4227f139dd..247b5d45e1 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -1,11 +1,12 @@ -import services -from controllers.files import api from flask import Response, request from flask_restful import Resource +from werkzeug.exceptions import NotFound + +import services +from controllers.files import api from libs.exception import BaseHTTPException from services.account_service import TenantService from services.file_service import FileService -from werkzeug.exceptions import NotFound class ImagePreviewApi(Resource): @@ -40,7 +41,7 @@ class WorkspaceWebappLogoApi(Resource): webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None if not webapp_logo_file_id: - raise NotFound(f'webapp logo is not found') + raise NotFound('webapp logo is not found') try: generator, mimetype = FileService.get_public_image_preview( diff --git a/api/controllers/files/tool_files.py 
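The `raise NotFound(f'...')` fixes above, like the earlier `f"No Embedding Model ..."` ones, strip f-prefixes from strings that contain no placeholders; ruff reports these as F541, since the prefix implies interpolation that never happens. The multi-line messages need no prefix at all because adjacent string literals concatenate at compile time:

```python
# flagged (F541): no placeholder, so the f-prefix is dead weight
# message = f"webapp logo is not found"

# fixed: plain literals; adjacent literals join at compile time
message = ("No Embedding Model available. Please configure a valid provider "
           "in the Settings -> Model Provider.")
assert message.endswith("Model Provider.")
```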
b/api/controllers/files/tool_files.py index b4a290ec87..0a254c1699 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -1,10 +1,11 @@ -from controllers.files import api -from core.tools.tool_file_manager import ToolFileManager from flask import Response from flask_restful import Resource, reqparse -from libs.exception import BaseHTTPException from werkzeug.exceptions import Forbidden, NotFound +from controllers.files import api +from core.tools.tool_file_manager import ToolFileManager +from libs.exception import BaseHTTPException + class ToolFilePreviewApi(Resource): def get(self, file_id, extension): @@ -31,7 +32,7 @@ class ToolFilePreviewApi(Resource): ) if not result: - raise NotFound(f'file is not found') + raise NotFound('file is not found') generator, mimetype = result except Exception: diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 63591f8f49..9cd9770c09 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,11 +1,11 @@ -# -*- coding:utf-8 -*- import json +from flask import current_app +from flask_restful import fields, marshal_with + from controllers.service_api import api from controllers.service_api.wraps import AppApiResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import App, AppModelConfig from models.tools import ApiToolProvider @@ -77,7 +77,7 @@ class AppMetaApi(AppApiResource): # get all tools tools = agent_config.get('tools', []) url_prefix = (current_app.config.get("CONSOLE_API_URL") - + f"/console/api/workspaces/current/tool-provider/builtin/") + + "/console/api/workspaces/current/tool-provider/builtin/") for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 9c5ae9a836..d2906b1d6e 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -1,21 +1,33 @@ import logging +from flask import request +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError + import services from controllers.service_api import api -from controllers.service_api.app.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.service_api.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.service_api.wraps import AppApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request -from flask_restful import reqparse from models.model import App, AppModelConfig from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + 
AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class AudioApi(AppApiResource): @@ -74,6 +86,7 @@ class TextApi(AppApiResource): tenant_id=app_model.tenant_id, text=args['text'], end_user=args['user'], + voice=app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=args['streaming'] ) diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index cc1ad0888e..5331f796e7 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -1,23 +1,31 @@ import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union + +from flask import Response, stream_with_context +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id -from controllers.service_api.app.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - NotChatAppError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.service_api.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + NotChatAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.service_api.wraps import AppApiResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService -from werkzeug.exceptions import InternalServerError, NotFound class CompletionApi(AppApiResource): @@ -175,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 604e2f93db..3c157bed99 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -1,16 +1,16 @@ -# -*- coding:utf-8 -*- +from flask import request +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from flask import request -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, 
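The audio hunks in this change thread a new `voice` argument from the app's `text_to_speech_dict` config into `AudioService.transcript_tts`, so each app can pin a TTS voice instead of relying on a provider default. Sketched in isolation (the config shape follows the diff; the function body is a stand-in, and 'alloy' is only an example voice name):

```python
from typing import Optional


def transcript_tts(tenant_id: str, text: str,
                   voice: Optional[str] = None, streaming: bool = False) -> str:
    # stand-in for AudioService.transcript_tts: fall back to the
    # provider default when the app config does not pin a voice
    return f"[tenant {tenant_id}] voice={voice or 'provider-default'}: {text}"


text_to_speech_dict = {'enabled': True, 'voice': 'alloy'}  # per-app TTS config
print(transcript_tts('tenant-1', 'hello',
                     voice=text_to_speech_dict.get('voice')))
```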
uuid_value +from libs.helper import uuid_value from services.conversation_service import ConversationService -from werkzeug.exceptions import NotFound class ConversationApi(AppApiResource): diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index 56beb56949..eb953d0950 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 8e7984ced1..a901375ec0 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -1,12 +1,17 @@ +from flask import request +from flask_restful import marshal_with + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id -from controllers.service_api.app.error import (FileTooLargeError, NoFileUploadedError, TooManyFilesError, - UnsupportedFileTypeError) +from controllers.service_api.app.error import ( + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) from controllers.service_api.wraps import AppApiResource from fields.file_fields import file_fields -from flask import request -from flask_restful import marshal_with from services.file_service import FileService diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index c90a1fb1e2..d90f536a42 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -1,4 +1,7 @@ -# -*- coding:utf-8 -*- +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + import services from controllers.service_api import api from controllers.service_api.app import create_or_update_end_user_for_user_id @@ -6,12 +9,9 @@ from controllers.service_api.app.error import NotChatAppError from controllers.service_api.wraps import AppApiResource from extensions.ext_database import db from fields.conversation_fields import message_file_fields -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from models.model import EndUser, Message from services.message_service import MessageService -from werkzeug.exceptions import NotFound class MessageListApi(AppApiResource): diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 900a796674..60c7ca4549 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,3 +1,6 @@ +from flask import request +from flask_restful import marshal, reqparse + import services.dataset_service from controllers.service_api import api from controllers.service_api.dataset.error import DatasetNameDuplicateError @@ -5,8 +8,6 @@ from controllers.service_api.wraps import DatasetApiResource from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields -from flask import request -from flask_restful import marshal, reqparse from libs.login import current_user from models.dataset import Dataset from services.dataset_service import DatasetService diff --git a/api/controllers/service_api/dataset/document.py 
b/api/controllers/service_api/dataset/document.py index d7694070f0..cbe0517ed3 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,23 +1,27 @@ import json +from flask import request +from flask_restful import marshal, reqparse +from sqlalchemy import desc +from werkzeug.exceptions import NotFound + import services.dataset_service from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError -from controllers.service_api.dataset.error import (ArchivedDocumentImmutableError, DocumentIndexingError, - NoFileUploadedError, TooManyFilesError) +from controllers.service_api.dataset.error import ( + ArchivedDocumentImmutableError, + DocumentIndexingError, + NoFileUploadedError, + TooManyFilesError, +) from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields -from flask import request -from flask_login import current_user -from flask_restful import marshal, reqparse from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment from services.dataset_service import DocumentService from services.file_service import FileService -from sqlalchemy import desc -from werkzeug.exceptions import NotFound class DocumentAddByTextApi(DatasetApiResource): diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 4cc313e042..d4a6b6aa4f 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -1,3 +1,7 @@ +from flask_login import current_user +from flask_restful import marshal, reqparse +from werkzeug.exceptions import NotFound + from controllers.service_api import api from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check @@ -6,11 +10,8 @@ from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import segment_fields -from flask_login import current_user -from flask_restful import marshal, reqparse from models.dataset import Dataset, DocumentSegment from services.dataset_service import DatasetService, DocumentService, SegmentService -from werkzeug.exceptions import NotFound class SegmentApi(DatasetApiResource): @@ -45,8 +46,8 @@ class SegmentApi(DatasetApiResource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args @@ -89,8 +90,8 @@ class SegmentApi(DatasetApiResource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. 
Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -181,8 +182,8 @@ class DatasetSegmentApi(DatasetApiResource): ) except LLMBadRequestError: raise ProviderNotInitializeError( - f"No Embedding Model available. Please configure a valid provider " - f"in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " + "in the Settings -> Model Provider.") except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index 489018cf9b..932388b562 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -1,7 +1,8 @@ -from controllers.service_api import api from flask import current_app from flask_restful import Resource +from controllers.service_api import api + class IndexApi(Resource): def get(self): diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 60e573ec93..a0d89fe62f 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,16 +1,16 @@ -# -*- coding:utf-8 -*- from datetime import datetime from functools import wraps -from extensions.ext_database import db from flask import current_app, request from flask_login import user_logged_in from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + +from extensions.ext_database import db from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin from models.model import ApiToken, App from services.feature_service import FeatureService -from werkzeug.exceptions import NotFound, Unauthorized def validate_app_token(view=None): diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 82a9ad8683..25492b1143 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,11 +1,11 @@ -# -*- coding:utf-8 -*- import json +from flask import current_app +from flask_restful import fields, marshal_with + from controllers.web import api from controllers.web.wraps import WebApiResource from extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import App, AppModelConfig from models.tools import ApiToolProvider @@ -76,7 +76,7 @@ class AppMeta(WebApiResource): # get all tools tools = agent_config.get('tools', []) url_prefix = (current_app.config.get("CONSOLE_API_URL") - + f"/console/api/workspaces/current/tool-provider/builtin/") + + "/console/api/workspaces/current/tool-provider/builtin/") for tool in tools: keys = list(tool.keys()) if len(keys) >= 4: diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 44ca7b660a..c628c16606 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -1,21 +1,32 @@ -# -*- coding:utf-8 -*- import logging +from flask import request +from werkzeug.exceptions import InternalServerError + import services from controllers.web import api -from controllers.web.error import (AppUnavailableError, AudioTooLargeError, CompletionRequestError, - NoAudioUploadedError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderNotSupportSpeechToTextError, - ProviderQuotaExceededError, UnsupportedAudioTypeError) +from controllers.web.error import ( + AppUnavailableError, + AudioTooLargeError, + 
CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import request from models.model import App, AppModelConfig from services.audio_service import AudioService -from services.errors.audio import (AudioTooLargeServiceError, NoAudioUploadedServiceError, - ProviderNotSupportSpeechToTextServiceError, UnsupportedAudioTypeServiceError) -from werkzeug.exceptions import InternalServerError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) class AudioApi(WebApiResource): @@ -57,17 +68,23 @@ class AudioApi(WebApiResource): except ValueError as e: raise e except Exception as e: - logging.exception("internal server error.") + logging.exception(f"internal server error: {str(e)}") raise InternalServerError() class TextApi(WebApiResource): def post(self, app_model: App, end_user): + app_model_config: AppModelConfig = app_model.app_model_config + + if not app_model_config.text_to_speech_dict['enabled']: + raise AppUnavailableError() + try: response = AudioService.transcript_tts( tenant_id=app_model.tenant_id, text=request.form['text'], end_user=end_user.external_user_id, + voice=app_model.app_model_config.text_to_speech_dict.get('voice'), streaming=False ) @@ -94,7 +111,7 @@ class TextApi(WebApiResource): except ValueError as e: raise e except Exception as e: - logging.exception("internal server error.") + logging.exception(f"internal server error: {str(e)}") raise InternalServerError() diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index af571f1ff7..61d4f8c362 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,23 +1,31 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union + +from flask import Response, stream_with_context +from flask_restful import reqparse +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.web import api -from controllers.web.error import (AppUnavailableError, CompletionRequestError, ConversationCompletedError, - NotChatAppError, NotCompletionAppError, ProviderModelCurrentlyNotSupportError, - ProviderNotInitializeError, ProviderQuotaExceededError) +from controllers.web.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + NotChatAppError, + NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.web.wraps import WebApiResource from core.application_queue_manager import ApplicationQueueManager from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError -from flask import Response, stream_with_context -from flask_restful import reqparse from libs.helper import uuid_value from services.completion_service import CompletionService -from werkzeug.exceptions import 
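The web `TextApi` above now rejects TTS requests up front when the feature is disabled in the app's model config, instead of letting the call fail deeper in the service layer. The guard reduces to a fail-fast pre-check; exception and config names follow the diff, while the surrounding app is stubbed:

```python
class AppUnavailableError(Exception):
    """Stand-in for the controller error raised by the real TextApi."""


def synthesize(app_model_config: dict, text: str) -> str:
    tts = app_model_config.get('text_to_speech_dict') or {}
    if not tts.get('enabled'):
        raise AppUnavailableError()  # fail fast, before any provider call
    return f"audio[{tts.get('voice')}]: {text}"


print(synthesize({'text_to_speech_dict': {'enabled': True, 'voice': 'alloy'}},
                 'hi'))
```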
InternalServerError, NotFound # define completion api for user @@ -146,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 1f17f7883e..c287f2a879 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -1,15 +1,15 @@ -# -*- coding:utf-8 -*- +from flask_restful import marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.web import api from controllers.web.error import NotChatAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import conversation_infinite_scroll_pagination_fields, simple_conversation_fields -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range -from libs.helper import TimestampField, uuid_value +from libs.helper import uuid_value from services.conversation_service import ConversationService from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.web_conversation_service import WebConversationService -from werkzeug.exceptions import NotFound class ConversationListApi(WebApiResource): diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 4566c323a2..9cb3c8f235 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- from libs.exception import BaseHTTPException diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index c43fe6fdf5..ca83f6037a 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -1,10 +1,11 @@ +from flask import request +from flask_restful import marshal_with + import services from controllers.web import api from controllers.web.error import FileTooLargeError, NoFileUploadedError, TooManyFilesError, UnsupportedFileTypeError from controllers.web.wraps import WebApiResource from fields.file_fields import file_fields -from flask import request -from flask_restful import marshal_with from services.file_service import FileService diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 2712e84691..e03bdd63bb 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -1,30 +1,37 @@ -# -*- coding:utf-8 -*- import json import logging -from typing import Generator, Union +from collections.abc import Generator +from typing import Union + +from flask import Response, stream_with_context +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import InternalServerError, NotFound import services from controllers.web import api -from controllers.web.error import (AppMoreLikeThisDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError, - CompletionRequestError, NotChatAppError, NotCompletionAppError, - ProviderModelCurrentlyNotSupportError, ProviderNotInitializeError, - ProviderQuotaExceededError) +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + AppSuggestedQuestionsAfterAnswerDisabledError, + CompletionRequestError, + NotChatAppError, + NotCompletionAppError, + 
ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) from controllers.web.wraps import WebApiResource from core.entities.application_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import message_file_fields from fields.message_fields import agent_thought_fields -from flask import Response, stream_with_context -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.completion_service import CompletionService from services.errors.app import MoreLikeThisDisabledError from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError from services.message_service import MessageService -from werkzeug.exceptions import InternalServerError, NotFound class MessageListApi(WebApiResource): @@ -153,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response: return Response(response=json.dumps(response), status=200, mimetype='application/json') else: def generate() -> Generator: - for chunk in response: - yield chunk + yield from response return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream') diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index bc6cf6028b..92b28d8125 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,13 +1,13 @@ -# -*- coding:utf-8 -*- import uuid +from flask import request +from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + from controllers.web import api from extensions.ext_database import db -from flask import request -from flask_restful import Resource from libs.passport import PassportService from models.model import App, EndUser, Site -from werkzeug.exceptions import NotFound, Unauthorized class PassportResource(Resource): diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index b353b9682e..e17869ffdb 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -1,13 +1,14 @@ +from flask_restful import fields, marshal_with, reqparse +from flask_restful.inputs import int_range +from werkzeug.exceptions import NotFound + from controllers.web import api from controllers.web.error import NotCompletionAppError from controllers.web.wraps import WebApiResource from fields.conversation_fields import message_file_fields -from flask_restful import fields, marshal_with, reqparse -from flask_restful.inputs import int_range from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -from werkzeug.exceptions import NotFound feedback_fields = { 'rating': fields.String diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 9f1297a06c..d8e2d59707 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,14 +1,13 @@ -# -*- coding:utf-8 -*- -import os + +from flask import current_app +from flask_restful import fields, marshal_with +from werkzeug.exceptions import Forbidden from controllers.web import api from controllers.web.wraps import WebApiResource from 
extensions.ext_database import db -from flask import current_app -from flask_restful import fields, marshal_with from models.model import Site from services.feature_service import FeatureService -from werkzeug.exceptions import Forbidden class AppSiteApi(WebApiResource): diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index 0803a3b5ea..bdaa476f34 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,12 +1,12 @@ -# -*- coding:utf-8 -*- from functools import wraps -from extensions.ext_database import db from flask import request from flask_restful import Resource +from werkzeug.exceptions import NotFound, Unauthorized + +from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site -from werkzeug.exceptions import NotFound, Unauthorized def validate_jwt_token(view=None): diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/agent/agent/agent_llm_callback.py index 8331731200..5ec549de8e 100644 --- a/api/core/agent/agent/agent_llm_callback.py +++ b/api/core/agent/agent/agent_llm_callback.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.model_runtime.callbacks.base_callback import Callback @@ -17,7 +17,7 @@ class AgentLLMCallback(Callback): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -38,7 +38,7 @@ class AgentLLMCallback(Callback): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -58,7 +58,7 @@ class AgentLLMCallback(Callback): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -80,7 +80,7 @@ class AgentLLMCallback(Callback): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py index 1ca6c49812..9c0f9c5b36 100644 --- a/api/core/agent/agent/calc_token_mixin.py +++ b/api/core/agent/agent/calc_token_mixin.py @@ -1,16 +1,14 @@ -from typing import List, cast +from typing import cast from 
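From here the agent modules swap `typing.List` and `typing.Tuple` for the built-in `list`/`tuple` generics and move `Sequence` and `Generator` to `collections.abc`. Since Python 3.9 (PEP 585) the built-ins are directly subscriptable and the `typing` aliases are deprecated; runtime behavior is identical. For example:

```python
from collections.abc import Sequence
from typing import Optional


def plan(intermediate_steps: list[tuple[str, str]],
         stop: Optional[list[str]] = None,
         tools: Sequence[str] = ()) -> str:
    # identical at runtime to the List/Tuple spelling, minus the imports
    return f"{len(intermediate_steps)} steps, {len(tools)} tools, stop={stop}"


print(plan([("search", "ok")], stop=["\nObservation:"]))
```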
core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_runtime.entities.message_entities import PromptMessage from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from langchain.schema import BaseMessage class CalcTokenMixin: - def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int: + def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int: """ Got the rest tokens available for the model after excluding messages tokens and completion max tokens diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py index c13641b84d..eb594c3d21 100644 --- a/api/core/agent/agent/multi_dataset_router_agent.py +++ b/api/core/agent/agent/multi_dataset_router_agent.py @@ -1,10 +1,6 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union -from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages -from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.third_party.langchain.llms.fake import FakeLLM from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message from langchain.callbacks.base import BaseCallbackManager @@ -14,6 +10,12 @@ from langchain.schema import AgentAction, AgentFinish, AIMessage, SystemMessage from langchain.tools import BaseTool from pydantic import root_validator +from core.entities.application_entities import ModelConfigEntity +from core.entities.message_entities import lc_messages_to_prompt_messages +from core.model_manager import ModelInstance +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.third_party.langchain.llms.fake import FakeLLM + class MultiDatasetRouterAgent(OpenAIFunctionsAgent): """ @@ -41,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -84,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): def real_plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -145,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): async def aplan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -157,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent): model_config: ModelConfigEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, + extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( 
content="You are a helpful AI assistant." ), diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py index e17282a293..1f2d5f24b3 100644 --- a/api/core/agent/agent/openai_function_call.py +++ b/api/core/agent/agent/openai_function_call.py @@ -1,4 +1,24 @@ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union + +from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent +from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message +from langchain.callbacks.base import BaseCallbackManager +from langchain.callbacks.manager import Callbacks +from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken +from langchain.memory.prompt import SUMMARY_PROMPT +from langchain.prompts.chat import BaseMessagePromptTemplate +from langchain.schema import ( + AgentAction, + AgentFinish, + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + get_buffer_string, +) +from langchain.tools import BaseTool +from pydantic import root_validator from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError @@ -7,19 +27,7 @@ from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.third_party.langchain.llms.fake import FakeLLM -from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent -from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message -from langchain.callbacks.base import BaseCallbackManager -from langchain.callbacks.manager import Callbacks -from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken -from langchain.memory.prompt import SUMMARY_PROMPT -from langchain.prompts.chat import BaseMessagePromptTemplate -from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, SystemMessage, - get_buffer_string) -from langchain.tools import BaseTool -from pydantic import root_validator class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin): @@ -44,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi model_config: ModelConfigEntity, tools: Sequence[BaseTool], callback_manager: Optional[BaseCallbackManager] = None, - extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None, + extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None, system_message: Optional[SystemMessage] = SystemMessage( content="You are a helpful AI assistant." 
), @@ -118,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -200,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi def return_stopped_response( self, early_stopping_method: str, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], **kwargs: Any, ) -> AgentFinish: try: @@ -208,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi except ValueError: return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "") - def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]: + def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]: # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0 rest_tokens = self.get_message_rest_tokens( self.model_config, @@ -257,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi return new_messages def predict_new_summary( - self, messages: List[BaseMessage], existing_summary: str + self, messages: list[BaseMessage], existing_summary: str ) -> str: new_lines = get_buffer_string( messages, @@ -268,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT) return chain.predict(summary=existing_summary, new_lines=new_lines) - def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int: + def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. Official documentation: https://github.com/openai/openai-cookbook/blob/ diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py index c8e6a84b09..e104bb01f9 100644 --- a/api/core/agent/agent/structed_multi_dataset_router_agent.py +++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py @@ -1,8 +1,7 @@ import re -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union, cast -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity from langchain import BasePromptTemplate, PromptTemplate from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE @@ -13,6 +12,9 @@ from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, Sy from langchain.schema import AgentAction, AgentFinish, OutputParserException from langchain.tools import BaseTool +from core.chain.llm_chain import LLMChain +from core.entities.application_entities import ModelConfigEntity + FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. 
Valid "action" values: "Final Answer" or {tool_names} @@ -67,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -124,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, ) -> BasePromptTemplate: tool_strings = [] for tool in tools: @@ -152,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent): tools: Sequence[BaseTool], prefix: str = PREFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, + input_variables: Optional[list[str]] = None, ) -> PromptTemplate: """Create prompt in the style of the zero shot agent. @@ -179,7 +181,7 @@ Thought: {agent_scratchpad} return PromptTemplate(template=template, input_variables=input_variables) def _construct_scratchpad( - self, intermediate_steps: List[Tuple[AgentAction, str]] + self, intermediate_steps: list[tuple[AgentAction, str]] ) -> str: agent_scratchpad = "" for action, observation in intermediate_steps: @@ -212,8 +214,8 @@ Thought: {agent_scratchpad} suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, **kwargs: Any, ) -> Agent: """Construct an agent from an LLM and tools.""" diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py index af0130b314..e1be624204 100644 --- a/api/core/agent/agent/structured_chat.py +++ b/api/core/agent/agent/structured_chat.py @@ -1,11 +1,7 @@ import re -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from collections.abc import Sequence +from typing import Any, Optional, Union, cast -from core.agent.agent.agent_llm_callback import AgentLLMCallback -from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError -from core.chain.llm_chain import LLMChain -from core.entities.application_entities import ModelConfigEntity -from core.entities.message_entities import lc_messages_to_prompt_messages from langchain import BasePromptTemplate, PromptTemplate from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE @@ -14,10 +10,23 @@ from langchain.callbacks.base import BaseCallbackManager from langchain.callbacks.manager import Callbacks from langchain.memory.prompt import SUMMARY_PROMPT from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate -from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, OutputParserException, - get_buffer_string) +from langchain.schema import ( + AgentAction, + AgentFinish, + AIMessage, + BaseMessage, + HumanMessage, + OutputParserException, + get_buffer_string, +) from langchain.tools import BaseTool +from 
core.agent.agent.agent_llm_callback import AgentLLMCallback +from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError +from core.chain.llm_chain import LLMChain +from core.entities.application_entities import ModelConfigEntity +from core.entities.message_entities import lc_messages_to_prompt_messages + FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input). The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English. Valid "action" values: "Final Answer" or {tool_names} @@ -74,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): def plan( self, - intermediate_steps: List[Tuple[AgentAction, str]], + intermediate_steps: list[tuple[AgentAction, str]], callbacks: Callbacks = None, **kwargs: Any, ) -> Union[AgentAction, AgentFinish]: @@ -119,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): return AgentFinish({"output": "I'm sorry, the answer of model is invalid, " "I don't know how to respond to that."}, "") - def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs): + def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs): if len(intermediate_steps) >= 2 and self.summary_model_config: should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1] should_summary_messages = [AIMessage(content=observation) @@ -146,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): return self.get_full_inputs([intermediate_steps[-1]], **kwargs) def predict_new_summary( - self, messages: List[BaseMessage], existing_summary: str + self, messages: list[BaseMessage], existing_summary: str ) -> str: new_lines = get_buffer_string( messages, @@ -165,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, ) -> BasePromptTemplate: tool_strings = [] for tool in tools: @@ -192,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin): tools: Sequence[BaseTool], prefix: str = PREFIX, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, + input_variables: Optional[list[str]] = None, ) -> PromptTemplate: """Create prompt in the style of the zero shot agent. 
@@ -219,7 +228,7 @@ Thought: {agent_scratchpad} return PromptTemplate(template=template, input_variables=input_variables) def _construct_scratchpad( - self, intermediate_steps: List[Tuple[AgentAction, str]] + self, intermediate_steps: list[tuple[AgentAction, str]] ) -> str: agent_scratchpad = "" for action, observation in intermediate_steps: @@ -252,8 +261,8 @@ Thought: {agent_scratchpad} suffix: str = SUFFIX, human_message_template: str = HUMAN_MESSAGE_TEMPLATE, format_instructions: str = FORMAT_INSTRUCTIONS, - input_variables: Optional[List[str]] = None, - memory_prompts: Optional[List[BasePromptTemplate]] = None, + input_variables: Optional[list[str]] = None, + memory_prompts: Optional[list[BasePromptTemplate]] = None, agent_llm_callback: Optional[AgentLLMCallback] = None, **kwargs: Any, ) -> Agent: diff --git a/api/core/agent/agent_executor.py b/api/core/agent/agent_executor.py index 2565fb2315..70fe00ee13 100644 --- a/api/core/agent/agent_executor.py +++ b/api/core/agent/agent_executor.py @@ -2,6 +2,12 @@ import enum import logging from typing import Optional, Union +from langchain.agents import AgentExecutor as LCAgentExecutor +from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent +from langchain.callbacks.manager import Callbacks +from langchain.tools import BaseTool +from pydantic import BaseModel, Extra + from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent @@ -15,11 +21,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.errors.invoke import InvokeError from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool -from langchain.agents import AgentExecutor as LCAgentExecutor -from langchain.agents import BaseMultiActionAgent, BaseSingleActionAgent -from langchain.callbacks.manager import Callbacks -from langchain.tools import BaseTool -from pydantic import BaseModel, Extra class PlanningStrategy(str, enum.Enum): diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index d751e301cc..2b8ddc5d4e 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -1,10 +1,16 @@ import time -from typing import Generator, List, Optional, Tuple, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast from core.application_queue_manager import ApplicationQueueManager, PublishFrom -from core.entities.application_entities import (ApplicationGenerateEntity, AppOrchestrationConfigEntity, - ExternalDataVariableEntity, InvokeFrom, ModelConfigEntity, - PromptTemplateEntity) +from core.entities.application_entities import ( + ApplicationGenerateEntity, + AppOrchestrationConfigEntity, + ExternalDataVariableEntity, + InvokeFrom, + ModelConfigEntity, + PromptTemplateEntity, +) from core.features.annotation_reply import AnnotationReplyFeature from core.features.external_data_fetch import ExternalDataFetchFeature from core.features.hosting_moderation import HostingModerationFeature @@ -79,7 +85,7 @@ class AppRunner: return rest_tokens def recale_llm_max_tokens(self, model_config: ModelConfigEntity, - prompt_messages: List[PromptMessage]): + prompt_messages: list[PromptMessage]): # recalc max_tokens if sum(prompt_token + max_tokens) over model token 
limit model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -121,7 +127,7 @@ class AppRunner: query: Optional[str] = None, context: Optional[str] = None, memory: Optional[TokenBufferMemory] = None) \ - -> Tuple[List[PromptMessage], Optional[List[str]]]: + -> tuple[list[PromptMessage], Optional[list[str]]]: """ Organize prompt messages :param context: @@ -290,7 +296,7 @@ class AppRunner: tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, inputs: dict, - query: str) -> Tuple[bool, dict, str]: + query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. :param app_id: app id diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index d0b9bb872c..a4845d0ff1 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -38,7 +38,7 @@ class AssistantApplicationRunner(AppRunner): """ app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() if not app_record: - raise ValueError(f"App not found") + raise ValueError("App not found") app_orchestration_config = application_generate_entity.app_orchestration_config_entity @@ -222,6 +222,7 @@ class AssistantApplicationRunner(AppRunner): conversation=conversation, message=message, query=query, + inputs=inputs, ) elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING: assistant_fc_runner = AssistantFunctionCallApplicationRunner( diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index ae2b712187..e1972efb51 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -35,7 +35,7 @@ class BasicApplicationRunner(AppRunner): """ app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first() if not app_record: - raise ValueError(f"App not found") + raise ValueError("App not found") app_orchestration_config = application_generate_entity.app_orchestration_config_entity diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 8a7e6f457a..20e4bc7992 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -1,30 +1,45 @@ import json import logging import time -from typing import Generator, Optional, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast + +from pydantic import BaseModel from core.app_runner.moderation_handler import ModerationRule, OutputModerationHandler from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import ApplicationGenerateEntity, InvokeFrom -from core.entities.queue_entities import (AnnotationReplyEvent, QueueAgentMessageEvent, QueueAgentThoughtEvent, - QueueErrorEvent, QueueMessageEndEvent, QueueMessageEvent, - QueueMessageFileEvent, QueueMessageReplaceEvent, QueuePingEvent, - QueueRetrieverResourcesEvent, QueueStopEvent) +from core.entities.queue_entities import ( + AnnotationReplyEvent, + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueErrorEvent, + QueueMessageEndEvent, + QueueMessageEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, 
QuotaExceededError from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageRole, - TextPromptMessageContent) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.prompt_template import PromptTemplateParser from core.tools.tool_file_manager import ToolFileManager -from core.tools.tool_manager import ToolManager from events.message_event import message_was_created from extensions.ext_database import db from models.model import Conversation, Message, MessageAgentThought, MessageFile -from pydantic import BaseModel from services.annotation_service import AppAnnotationService logger = logging.getLogger(__name__) @@ -104,7 +119,7 @@ class GenerateTaskPipeline: } self._task_state.llm_result.message.content = annotation.content - elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)): + elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: @@ -187,7 +202,7 @@ class GenerateTaskPipeline: data = self._error_to_stream_response_data(self._handle_error(event)) yield self._yield_response(data) break - elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)): + elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): if isinstance(event, QueueMessageEndEvent): self._task_state.llm_result = event.llm_result else: @@ -339,7 +354,7 @@ class GenerateTaskPipeline: yield self._yield_response(response) - elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)): + elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent): chunk = event.chunk delta_text = chunk.delta.message.content if delta_text is None: @@ -477,7 +492,11 @@ class GenerateTaskPipeline: } # Determine the response based on the type of exception - data = error_responses.get(type(e)) + data = None + for k, v in error_responses.items(): + if isinstance(e, k): + data = v + if data: data.setdefault('message', getattr(e, 'description', str(e))) else: diff --git a/api/core/app_runner/moderation_handler.py b/api/core/app_runner/moderation_handler.py index 24ea085612..b2098344c8 100644 --- a/api/core/app_runner/moderation_handler.py +++ b/api/core/app_runner/moderation_handler.py @@ -1,20 +1,21 @@ import logging import threading import time -from typing import Any, Dict, Optional +from typing import Any, Optional + +from flask import Flask, current_app +from pydantic import BaseModel from core.application_queue_manager import PublishFrom from core.moderation.base import ModerationAction, ModerationOutputsResult from core.moderation.factory import ModerationFactory -from flask import Flask, current_app -from pydantic import BaseModel logger = logging.getLogger(__name__) class ModerationRule(BaseModel): type: str - config: Dict[str, Any] + config: dict[str, Any] class OutputModerationHandler(BaseModel): diff --git a/api/core/application_manager.py b/api/core/application_manager.py index 
7f07bed3a5..e073eac4b9 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -2,19 +2,34 @@ import json import logging import threading import uuid -from typing import Any, Generator, Optional, Tuple, Union, cast +from collections.abc import Generator +from typing import Any, Optional, Union, cast + +from flask import Flask, current_app +from pydantic import ValidationError from core.app_runner.assistant_app_runner import AssistantApplicationRunner from core.app_runner.basic_app_runner import BasicApplicationRunner from core.app_runner.generate_task_pipeline import GenerateTaskPipeline from core.application_queue_manager import ApplicationQueueManager, ConversationTaskStoppedException, PublishFrom -from core.entities.application_entities import (AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, AgentEntity, AgentPromptEntity, - AgentToolEntity, ApplicationGenerateEntity, - AppOrchestrationConfigEntity, DatasetEntity, - DatasetRetrieveConfigEntity, ExternalDataVariableEntity, - FileUploadEntity, InvokeFrom, ModelConfigEntity, PromptTemplateEntity, - SensitiveWordAvoidanceEntity) +from core.entities.application_entities import ( + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + AgentEntity, + AgentPromptEntity, + AgentToolEntity, + ApplicationGenerateEntity, + AppOrchestrationConfigEntity, + DatasetEntity, + DatasetRetrieveConfigEntity, + ExternalDataVariableEntity, + FileUploadEntity, + InvokeFrom, + ModelConfigEntity, + PromptTemplateEntity, + SensitiveWordAvoidanceEntity, + TextToSpeechEntity, +) from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file.file_obj import FileObj @@ -26,10 +41,8 @@ from core.prompt.prompt_template import PromptTemplateParser from core.provider_manager import ProviderManager from core.tools.prompt.template import REACT_PROMPT_TEMPLATES from extensions.ext_database import db -from flask import Flask, current_app from models.account import Account from models.model import App, Conversation, EndUser, Message, MessageFile -from pydantic import ValidationError logger = logging.getLogger(__name__) @@ -560,7 +573,11 @@ class ApplicationManager: text_to_speech_dict = copy_app_model_config_dict.get('text_to_speech') if text_to_speech_dict: if 'enabled' in text_to_speech_dict and text_to_speech_dict['enabled']: - properties['text_to_speech'] = True + properties['text_to_speech'] = TextToSpeechEntity( + enabled=text_to_speech_dict.get('enabled'), + voice=text_to_speech_dict.get('voice'), + language=text_to_speech_dict.get('language'), + ) # sensitive word avoidance sensitive_word_avoidance_dict = copy_app_model_config_dict.get('sensitive_word_avoidance') @@ -574,7 +591,7 @@ class ApplicationManager: return AppOrchestrationConfigEntity(**properties) def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \ - -> Tuple[Conversation, Message]: + -> tuple[Conversation, Message]: """ Initialize generate records :param application_generate_entity: application generate entity diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py index 65b52fd1f3..9590a1e726 100644 --- a/api/core/application_queue_manager.py +++ b/api/core/application_queue_manager.py @@ -1,17 +1,30 @@ import queue import time +from collections.abc import Generator from enum import Enum -from typing import Any, Generator +from typing import 
Any + +from sqlalchemy.orm import DeclarativeMeta from core.entities.application_entities import InvokeFrom -from core.entities.queue_entities import (AnnotationReplyEvent, AppQueueEvent, QueueAgentMessageEvent, - QueueAgentThoughtEvent, QueueErrorEvent, QueueMessage, QueueMessageEndEvent, - QueueMessageEvent, QueueMessageFileEvent, QueueMessageReplaceEvent, - QueuePingEvent, QueueRetrieverResourcesEvent, QueueStopEvent) +from core.entities.queue_entities import ( + AnnotationReplyEvent, + AppQueueEvent, + QueueAgentMessageEvent, + QueueAgentThoughtEvent, + QueueErrorEvent, + QueueMessage, + QueueMessageEndEvent, + QueueMessageEvent, + QueueMessageFileEvent, + QueueMessageReplaceEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, +) from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from extensions.ext_redis import redis_client from models.model import MessageAgentThought, MessageFile -from sqlalchemy.orm import DeclarativeMeta class PublishFrom(Enum): diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py index edee77e25f..1d25b8ab69 100644 --- a/api/core/callback_handler/agent_loop_gather_callback_handler.py +++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py @@ -1,7 +1,11 @@ import json import logging import time -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Optional, Union, cast + +from langchain.agents import openai_functions_agent, openai_functions_multi_agent +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.callback_handler.entity.agent_loop import AgentLoop @@ -10,9 +14,6 @@ from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResu from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db -from langchain.agents import openai_functions_agent, openai_functions_multi_agent -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import AgentAction, AgentFinish, BaseMessage, ChatGeneration, LLMResult from models.model import Message, MessageAgentThought, MessageChain @@ -36,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): self._message_agent_thought = None @property - def agent_loops(self) -> List[AgentLoop]: + def agent_loops(self) -> list[AgentLoop]: return self._agent_loops def clear_agent_loops(self) -> None: @@ -94,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], **kwargs: Any ) -> Any: pass def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any ) -> None: pass @@ -119,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler): def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, **kwargs: Any, ) -> None: diff --git a/api/core/callback_handler/agent_tool_callback_handler.py 
b/api/core/callback_handler/agent_tool_callback_handler.py index ae77bf6cd1..3fed7d0ad5 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text @@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): def on_tool_start( self, tool_name: str, - tool_inputs: Dict[str, Any], + tool_inputs: dict[str, Any], ) -> None: """Do nothing.""" print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color) @@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel): def on_tool_end( self, tool_name: str, - tool_inputs: Dict[str, Any], + tool_inputs: dict[str, Any], tool_outputs: str, ) -> None: """If not the final action, print out observation.""" diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 9947028806..903953486a 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,9 +1,9 @@ -from typing import List, Union + +from langchain.schema import Document from core.application_queue_manager import ApplicationQueueManager, PublishFrom from core.entities.application_entities import InvokeFrom from extensions.ext_database import db -from langchain.schema import Document from models.dataset import DatasetQuery, DocumentSegment from models.model import DatasetRetrieverResource @@ -39,22 +39,26 @@ class DatasetIndexToolCallbackHandler: db.session.add(dataset_query) db.session.commit() - def on_tool_end(self, documents: List[Document]) -> None: + def on_tool_end(self, documents: list[Document]) -> None: """Handle tool end.""" for document in documents: - doc_id = document.metadata['doc_id'] + query = db.session.query(DocumentSegment).filter( + DocumentSegment.index_node_id == document.metadata['doc_id'] + ) + + # scope the hit-count update to the owning dataset when dataset_id is present + if 'dataset_id' in document.metadata: + query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id']) # add hit count to document segment - db.session.query(DocumentSegment).filter( - DocumentSegment.index_node_id == doc_id - ).update( + query.update( {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False ) db.session.commit() - def return_retriever_resource_info(self, resource: List): + def return_retriever_resource_info(self, resource: list): """Handle return_retriever_resource_info.""" if resource and len(resource) > 0: for item in resource: diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py index 9f586d2c9b..1f95471afb 100644 --- a/api/core/callback_handler/std_out_callback_handler.py +++ b/api/core/callback_handler/std_out_callback_handler.py @@ -1,6 +1,6 @@ import os import sys -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.input import print_text @@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): def on_chat_model_start( self, - serialized: Dict[str, Any], - messages: List[List[BaseMessage]], + serialized: dict[str, Any], + messages: list[list[BaseMessage]], 
**kwargs: Any ) -> Any: print_text("\n[on_chat_model_start]\n", color='blue') @@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): print_text(str(sub_message) + "\n", color='blue') def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any + self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any ) -> None: """Print out the prompts.""" print_text("\n[on_llm_start]\n", color='blue') @@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue') def on_chain_start( - self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any ) -> None: """Print out that we are entering a chain.""" chain_type = serialized['id'][-1] print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink') - def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None: + def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None: """Print out that we finished a chain.""" print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink') @@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler): def on_tool_start( self, - serialized: Dict[str, Any], + serialized: dict[str, Any], input_str: str, **kwargs: Any, ) -> None: diff --git a/api/core/chain/llm_chain.py b/api/core/chain/llm_chain.py index 20b71f2f64..86fb156292 100644 --- a/api/core/chain/llm_chain.py +++ b/api/core/chain/llm_chain.py @@ -1,26 +1,27 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional + +from langchain import LLMChain as LCLLMChain +from langchain.callbacks.manager import CallbackManagerForChainRun +from langchain.schema import Generation, LLMResult +from langchain.schema.language_model import BaseLanguageModel from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.entities.application_entities import ModelConfigEntity from core.entities.message_entities import lc_messages_to_prompt_messages from core.model_manager import ModelInstance from core.third_party.langchain.llms.fake import FakeLLM -from langchain import LLMChain as LCLLMChain -from langchain.callbacks.manager import CallbackManagerForChainRun -from langchain.schema import Generation, LLMResult -from langchain.schema.language_model import BaseLanguageModel class LLMChain(LCLLMChain): model_config: ModelConfigEntity """The language model instance to use.""" llm: BaseLanguageModel = FakeLLM(response="") - parameters: Dict[str, Any] = {} + parameters: dict[str, Any] = {} agent_llm_callback: Optional[AgentLLMCallback] = None def generate( self, - input_list: List[Dict[str, Any]], + input_list: list[dict[str, Any]], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> LLMResult: """Generate LLM result from inputs.""" diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py index 14a9693623..4a6eb3654d 100644 --- a/api/core/data_loader/file_extractor.py +++ b/api/core/data_loader/file_extractor.py @@ -1,8 +1,12 @@ import tempfile from pathlib import Path -from typing import List, Optional, Union +from typing import Optional, Union import requests +from flask import current_app +from langchain.document_loaders import Docx2txtLoader, TextLoader +from langchain.schema import Document + from core.data_loader.loader.csv_loader import CSVLoader from core.data_loader.loader.excel import 
ExcelLoader from core.data_loader.loader.html import HTMLLoader @@ -16,9 +20,6 @@ from core.data_loader.loader.unstructured.unstructured_pptx import UnstructuredP from core.data_loader.loader.unstructured.unstructured_text import UnstructuredTextLoader from core.data_loader.loader.unstructured.unstructured_xml import UnstructuredXmlLoader from extensions.ext_storage import storage -from flask import current_app -from langchain.document_loaders import Docx2txtLoader, TextLoader -from langchain.schema import Document from models.model import UploadFile SUPPORT_URL_CONTENT_TYPES = ['application/pdf', 'text/plain'] @@ -27,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM class FileExtractor: @classmethod - def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]: + def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]: with tempfile.TemporaryDirectory() as temp_dir: suffix = Path(upload_file.key).suffix file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" @@ -36,7 +37,7 @@ class FileExtractor: return cls.load_from_file(file_path, return_text, upload_file, is_automatic) @classmethod - def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]: + def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]: response = requests.get(url, headers={ "User-Agent": USER_AGENT }) @@ -52,7 +53,7 @@ class FileExtractor: @classmethod def load_from_file(cls, file_path: str, return_text: bool = False, upload_file: Optional[UploadFile] = None, - is_automatic: bool = False) -> Union[List[Document], str]: + is_automatic: bool = False) -> Union[list[Document], str]: input_file = Path(file_path) delimiter = '\n' file_extension = input_file.suffix.lower() diff --git a/api/core/data_loader/loader/csv_loader.py b/api/core/data_loader/loader/csv_loader.py index a4d4ed2b39..ce252c157e 100644 --- a/api/core/data_loader/loader/csv_loader.py +++ b/api/core/data_loader/loader/csv_loader.py @@ -1,6 +1,6 @@ import csv import logging -from typing import Dict, List, Optional +from typing import Optional from langchain.document_loaders import CSVLoader as LCCSVLoader from langchain.document_loaders.helpers import detect_file_encodings @@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader): self, file_path: str, source_column: Optional[str] = None, - csv_args: Optional[Dict] = None, + csv_args: Optional[dict] = None, encoding: Optional[str] = None, autodetect_encoding: bool = True, ): @@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader): self.csv_args = csv_args or {} self.autodetect_encoding = autodetect_encoding - def load(self) -> List[Document]: + def load(self) -> list[Document]: """Load data into document objects.""" try: with open(self.file_path, newline="", encoding=self.encoding) as csvfile: diff --git a/api/core/data_loader/loader/excel.py b/api/core/data_loader/loader/excel.py index 5e76c21a8f..cddb298547 100644 --- a/api/core/data_loader/loader/excel.py +++ b/api/core/data_loader/loader/excel.py @@ -1,6 +1,4 @@ -import json import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +22,7 @@ class ExcelLoader(BaseLoader): """Initialize with file path.""" self._file_path = file_path - def load(self) -> List[Document]: + def load(self) -> list[Document]: data = [] keys = [] wb = 
load_workbook(filename=self._file_path, read_only=True) diff --git a/api/core/data_loader/loader/html.py b/api/core/data_loader/loader/html.py index 414975007b..6a9b48a5b2 100644 --- a/api/core/data_loader/loader/html.py +++ b/api/core/data_loader/loader/html.py @@ -1,5 +1,4 @@ import logging -from typing import List from bs4 import BeautifulSoup from langchain.document_loaders.base import BaseLoader @@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader): """Initialize with file path.""" self._file_path = file_path - def load(self) -> List[Document]: + def load(self) -> list[Document]: return [Document(page_content=self._load_as_text())] def _load_as_text(self) -> str: diff --git a/api/core/data_loader/loader/markdown.py b/api/core/data_loader/loader/markdown.py index 545c6b10ed..ecbc6d548f 100644 --- a/api/core/data_loader/loader/markdown.py +++ b/api/core/data_loader/loader/markdown.py @@ -1,6 +1,6 @@ import logging import re -from typing import List, Optional, Tuple, cast +from typing import Optional, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings @@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader): self._encoding = encoding self._autodetect_encoding = autodetect_encoding - def load(self) -> List[Document]: + def load(self) -> list[Document]: tups = self.parse_tups(self._file_path) documents = [] for header, value in tups: @@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader): return documents - def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: + def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]: """Convert a markdown file to a dictionary. The keys are the headers and the values are the text under each header. """ - markdown_tups: List[Tuple[Optional[str], str]] = [] + markdown_tups: list[tuple[Optional[str], str]] = [] lines = markdown_text.split("\n") current_header = None @@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader): content = re.sub(pattern, r"\1", content) return content - def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]: + def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: """Parse file into tuples.""" content = "" try: - with open(filepath, "r", encoding=self._encoding) as f: + with open(filepath, encoding=self._encoding) as f: content = f.read() except UnicodeDecodeError as e: if self._autodetect_encoding: diff --git a/api/core/data_loader/loader/notion.py b/api/core/data_loader/loader/notion.py index 914c04d5c0..f8d8837683 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -1,12 +1,13 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional import requests -from extensions.ext_database import db from flask import current_app from langchain.document_loaders.base import BaseLoader from langchain.schema import Document + +from extensions.ext_database import db from models.dataset import Document as DocumentModel from models.source import DataSourceBinding @@ -66,7 +67,7 @@ class NotionLoader(BaseLoader): document_model=document_model ) - def load(self) -> List[Document]: + def load(self) -> list[Document]: self.update_last_edited_time( self._document_model ) @@ -77,7 +78,7 @@ class NotionLoader(BaseLoader): def _load_data_as_documents( self, notion_obj_id: str, notion_page_type: str - ) -> List[Document]: + ) -> list[Document]: docs = [] if notion_page_type == 'database': # get all the pages in 
the database @@ -93,8 +94,8 @@ class NotionLoader(BaseLoader): return docs def _get_notion_database_data( - self, database_id: str, query_dict: Dict[str, Any] = {} - ) -> List[Document]: + self, database_id: str, query_dict: dict[str, Any] = {} + ) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -148,12 +149,12 @@ class NotionLoader(BaseLoader): return database_content_list - def _get_notion_block_data(self, page_id: str) -> List[str]: + def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] cur_block_id = page_id while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -215,7 +216,7 @@ class NotionLoader(BaseLoader): cur_block_id = block_id while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -279,7 +280,7 @@ class NotionLoader(BaseLoader): cur_block_id = block_id while not done: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -345,7 +346,7 @@ class NotionLoader(BaseLoader): else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py index 8b08393d91..a3452b367b 100644 --- a/api/core/data_loader/loader/pdf.py +++ b/api/core/data_loader/loader/pdf.py @@ -1,10 +1,11 @@ import logging -from typing import List, Optional +from typing import Optional -from extensions.ext_storage import storage from langchain.document_loaders import PyPDFium2Loader from langchain.document_loaders.base import BaseLoader from langchain.schema import Document + +from extensions.ext_storage import storage from models.model import UploadFile logger = logging.getLogger(__name__) @@ -27,7 +28,7 @@ class PdfLoader(BaseLoader): self._file_path = file_path self._upload_file = upload_file - def load(self) -> List[Document]: + def load(self) -> list[Document]: plaintext_file_key = '' plaintext_file_exists = False if self._upload_file: diff --git a/api/core/data_loader/loader/unstructured/unstructured_eml.py b/api/core/data_loader/loader/unstructured/unstructured_eml.py index 26e0ce8cda..2fa3aac133 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_eml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_eml.py @@ -1,6 +1,5 @@ import base64 import logging -from typing import List from bs4 import BeautifulSoup from langchain.document_loaders.base import BaseLoader @@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.email import partition_email elements = partition_email(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_markdown.py b/api/core/data_loader/loader/unstructured/unstructured_markdown.py index cf6e7c9c8a..036a2afd25 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_markdown.py +++ b/api/core/data_loader/loader/unstructured/unstructured_markdown.py @@ -1,5 +1,4 @@ import logging -from 
typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.md import partition_md elements = partition_md(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_msg.py b/api/core/data_loader/loader/unstructured/unstructured_msg.py index 5a9813237e..495be328ed 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_msg.py +++ b/api/core/data_loader/loader/unstructured/unstructured_msg.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.msg import partition_msg elements = partition_msg(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_ppt.py b/api/core/data_loader/loader/unstructured/unstructured_ppt.py index 9b1e6b5abf..cfac91cc7b 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_ppt.py +++ b/api/core/data_loader/loader/unstructured/unstructured_ppt.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.ppt import partition_ppt elements = partition_ppt(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_pptx.py b/api/core/data_loader/loader/unstructured/unstructured_pptx.py index 0eecee9ffe..41e3bfcb54 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_pptx.py +++ b/api/core/data_loader/loader/unstructured/unstructured_pptx.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_text.py b/api/core/data_loader/loader/unstructured/unstructured_text.py index dd684b37f2..09d14fdb17 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_text.py +++ b/api/core/data_loader/loader/unstructured/unstructured_text.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.text import partition_text elements = partition_text(filename=self._file_path, api_url=self._api_url) diff --git 
a/api/core/data_loader/loader/unstructured/unstructured_xml.py b/api/core/data_loader/loader/unstructured/unstructured_xml.py index 0ddbb74b9c..cca6e1b0b7 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_xml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_xml.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.xml import partition_xml elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 49e87ec340..556b3aceda 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,12 +1,14 @@ -from typing import Any, Dict, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Optional, cast + +from langchain.schema import Document +from sqlalchemy import func from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db -from langchain.schema import Document from models.dataset import Dataset, DocumentSegment -from sqlalchemy import func class DatasetDocumentStore: @@ -21,10 +23,10 @@ class DatasetDocumentStore: self._document_id = document_id @classmethod - def from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore": + def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore": return cls(**config_dict) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Serialize to dict.""" return { "dataset_id": self._dataset.id, @@ -39,7 +41,7 @@ class DatasetDocumentStore: return self._user_id @property - def docs(self) -> Dict[str, Document]: + def docs(self) -> dict[str, Document]: document_segments = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == self._dataset.id ).all() diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 185b87b8b6..a86afd817a 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,18 +1,17 @@ import base64 -import json import logging -from typing import List, Optional, cast +from typing import Optional, cast import numpy as np +from langchain.embeddings.base import Embeddings +from sqlalchemy.exc import IntegrityError + from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client -from langchain.embeddings.base import Embeddings from libs import helper -from models.dataset import Embedding -from sqlalchemy.exc import IntegrityError logger = logging.getLogger(__name__) @@ -22,7 +21,7 @@ class CacheEmbedding(Embeddings): self._model_instance = model_instance self._user = user - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" text_embeddings 
= [] try: @@ -53,7 +52,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) diff --git a/api/core/entities/application_entities.py b/api/core/entities/application_entities.py index 6883a004e4..abcf605c92 100644 --- a/api/core/entities/application_entities.py +++ b/api/core/entities/application_entities.py @@ -1,11 +1,12 @@ from enum import Enum -from typing import Any, Literal, Optional, Union, cast +from typing import Any, Literal, Optional, Union + +from pydantic import BaseModel from core.entities.provider_configuration import ProviderModelBundle from core.file.file_obj import FileObj from core.model_runtime.entities.message_entities import PromptMessageRole from core.model_runtime.entities.model_entities import AIModelEntity -from pydantic import BaseModel class ModelConfigEntity(BaseModel): @@ -41,6 +42,7 @@ class AdvancedCompletionPromptTemplateEntity(BaseModel): """ Advanced Completion Prompt Template Entity. """ + class RolePrefixEntity(BaseModel): """ Role Prefix Entity. @@ -56,6 +58,7 @@ class PromptTemplateEntity(BaseModel): """ Prompt Template Entity. """ + class PromptType(Enum): """ Prompt Type. @@ -96,6 +99,7 @@ class DatasetRetrieveConfigEntity(BaseModel): """ Dataset Retrieve Config Entity. """ + class RetrieveStrategy(Enum): """ Dataset Retrieve Strategy. @@ -142,6 +146,15 @@ class SensitiveWordAvoidanceEntity(BaseModel): config: dict[str, Any] = {} +class TextToSpeechEntity(BaseModel): + """ + Text To Speech Entity. + """ + enabled: bool + voice: Optional[str] = None + language: Optional[str] = None + + class FileUploadEntity(BaseModel): """ File Upload Entity. @@ -158,6 +171,7 @@ class AgentToolEntity(BaseModel): tool_name: str tool_parameters: dict[str, Any] = {} + class AgentPromptEntity(BaseModel): """ Agent Prompt Entity. @@ -165,6 +179,7 @@ class AgentPromptEntity(BaseModel): first_prompt: str next_iteration: str + class AgentScratchpadUnit(BaseModel): """ Agent First Prompt Entity. @@ -181,12 +196,14 @@ class AgentScratchpadUnit(BaseModel): thought: Optional[str] = None action_str: Optional[str] = None observation: Optional[str] = None - action: Optional[Action] = None + action: Optional[Action] = None + class AgentEntity(BaseModel): """ Agent Entity. """ + class Strategy(Enum): """ Agent Strategy. @@ -201,6 +218,7 @@ class AgentEntity(BaseModel): tools: list[AgentToolEntity] = None max_iteration: int = 5 + class AppOrchestrationConfigEntity(BaseModel): """ App Orchestration Config Entity. 
@@ -218,7 +236,7 @@ class AppOrchestrationConfigEntity(BaseModel): show_retrieve_source: bool = False more_like_this: bool = False speech_to_text: bool = False - text_to_speech: bool = False + text_to_speech: dict = {} sensitive_word_avoidance: Optional[SensitiveWordAvoidanceEntity] = None diff --git a/api/core/entities/message_entities.py b/api/core/entities/message_entities.py index 9b0b287f28..6f767aafc7 100644 --- a/api/core/entities/message_entities.py +++ b/api/core/entities/message_entities.py @@ -1,12 +1,19 @@ import enum from typing import Any, cast -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, SystemPromptMessage, TextPromptMessageContent, - ToolPromptMessage, UserPromptMessage) from langchain.schema import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage from pydantic import BaseModel +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) + class PromptMessageFileType(enum.Enum): IMAGE = 'image' @@ -34,7 +41,7 @@ class ImagePromptMessageFile(PromptMessageFile): class LCHumanMessageWithFiles(HumanMessage): - # content: Union[str, List[Union[str, Dict]]] + # content: Union[str, list[Union[str, Dict]]] content: str files: list[PromptMessageFile] diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 3888807227..05719e5b8d 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,10 +1,11 @@ from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ProviderEntity, SimpleProviderEntity -from pydantic import BaseModel +from core.model_runtime.entities.provider_entities import ProviderEntity class ModelStatus(Enum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a7a365fe69..b83ae0c8e7 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,23 +1,28 @@ import datetime import json import logging +from collections.abc import Iterator from json import JSONDecodeError -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Optional + +from pydantic import BaseModel from core.entities.model_entities import ModelStatus, ModelWithProviderEntity, SimpleModelProviderEntity from core.entities.provider_entities import CustomConfiguration, SystemConfiguration, SystemConfigurationStatus from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import FetchFrom, ModelType -from core.model_runtime.entities.provider_entities import (ConfigurateMethod, CredentialFormSchema, FormType, - ProviderEntity) +from core.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.ai_model import AIModel from core.model_runtime.model_providers.__base.model_provider import ModelProvider -from 
core.model_runtime.utils import encoders from extensions.ext_database import db from models.provider import Provider, ProviderModel, ProviderType, TenantPreferredModelProvider -from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -131,7 +136,7 @@ class ProviderConfiguration(BaseModel): if self.provider.provider_credential_schema else [] ) - def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -278,7 +283,7 @@ class ProviderConfiguration(BaseModel): return None def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> Tuple[ProviderModel, dict]: + -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -707,7 +712,7 @@ class ProviderConfigurations(BaseModel): Model class for provider configuration dict. """ tenant_id: str - configurations: Dict[str, ProviderConfiguration] = {} + configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) @@ -755,7 +760,7 @@ class ProviderConfigurations(BaseModel): return all_models - def to_list(self) -> List[ProviderConfiguration]: + def to_list(self) -> list[ProviderConfiguration]: """ Convert to list. diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index ab6fea0a2f..114dfaf911 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.model_entities import ModelType from models.provider import ProviderQuotaType -from pydantic import BaseModel class QuotaUnit(Enum): diff --git a/api/core/entities/queue_entities.py b/api/core/entities/queue_entities.py index d6ef28b138..c1f8fb7e89 100644 --- a/api/core/entities/queue_entities.py +++ b/api/core/entities/queue_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Any -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk + class QueueEvent(Enum): """ diff --git a/api/core/extension/api_based_extension_requestor.py b/api/core/extension/api_based_extension_requestor.py index c244fe88f1..40e60687b2 100644 --- a/api/core/extension/api_based_extension_requestor.py +++ b/api/core/extension/api_based_extension_requestor.py @@ -1,6 +1,7 @@ import os import requests + from models.api_based_extension import APIBasedExtensionPoint @@ -30,10 +31,10 @@ class APIBasedExtensionRequestor: try: # proxy support for security proxies = None - if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"): + if os.environ.get("SSRF_PROXY_HTTP_URL") and os.environ.get("SSRF_PROXY_HTTPS_URL"): proxies = { - 'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"), - 'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"), + 'http': os.environ.get("SSRF_PROXY_HTTP_URL"), + 'https': os.environ.get("SSRF_PROXY_HTTPS_URL"), } response = requests.request( diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 6b27062f13..c19aaefe9e 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -61,7 +61,7 @@ class Extensible: 
builtin_file_path = os.path.join(subdir_path, '__builtin__') if os.path.exists(builtin_file_path): - with open(builtin_file_path, 'r', encoding='utf-8') as f: + with open(builtin_file_path, encoding='utf-8') as f: position = int(f.read().strip()) if (extension_name + '.py') not in file_names: @@ -93,7 +93,7 @@ class Extensible: json_path = os.path.join(subdir_path, 'schema.json') json_data = {} if os.path.exists(json_path): - with open(json_path, 'r', encoding='utf-8') as f: + with open(json_path, encoding='utf-8') as f: json_data = json.load(f) extensions[extension_name] = ModuleExtension( diff --git a/api/core/features/agent_runner.py b/api/core/features/agent_runner.py index 66d41dace0..7412d81281 100644 --- a/api/core/features/agent_runner.py +++ b/api/core/features/agent_runner.py @@ -1,5 +1,7 @@ import logging -from typing import List, Optional, cast +from typing import Optional, cast + +from langchain.tools import BaseTool from core.agent.agent.agent_llm_callback import AgentLLMCallback from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy @@ -7,20 +9,20 @@ from core.application_queue_manager import ApplicationQueueManager from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler -from core.entities.application_entities import (AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity, InvokeFrom, - ModelConfigEntity) +from core.entities.application_entities import ( + AgentEntity, + AppOrchestrationConfigEntity, + InvokeFrom, + ModelConfigEntity, +) from core.memory.token_buffer_memory import TokenBufferMemory from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.model_providers import model_provider_factory from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db -from langchain import WikipediaAPIWrapper -from langchain.callbacks.base import BaseCallbackHandler -from langchain.tools import BaseTool, Tool, WikipediaQueryRun from models.dataset import Dataset from models.model import Message -from pydantic import BaseModel, Field logger = logging.getLogger(__name__) diff --git a/api/core/features/annotation_reply.py b/api/core/features/annotation_reply.py index 09945aaf6e..bdc5467e62 100644 --- a/api/core/features/annotation_reply.py +++ b/api/core/features/annotation_reply.py @@ -1,13 +1,14 @@ import logging from typing import Optional +from flask import current_app + from core.embedding.cached_embedding import CacheEmbedding from core.entities.application_entities import InvokeFrom from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db -from flask import current_app from models.dataset import Dataset from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 5538918234..c62028eaf0 100644 --- a/api/core/features/assistant_base_runner.py +++ 
b/api/core/features/assistant_base_runner.py @@ -2,14 +2,20 @@ import json import logging from datetime import datetime from mimetypes import guess_extension -from typing import List, Optional, Tuple, Union, cast +from typing import Optional, Union, cast from core.app_runner.app_runner import AppRunner from core.application_queue_manager import ApplicationQueueManager from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler -from core.entities.application_entities import (AgentEntity, AgentToolEntity, ApplicationGenerateEntity, - AppOrchestrationConfigEntity, InvokeFrom, ModelConfigEntity) +from core.entities.application_entities import ( + AgentEntity, + AgentToolEntity, + ApplicationGenerateEntity, + AppOrchestrationConfigEntity, + InvokeFrom, + ModelConfigEntity, +) from core.file.message_file_parser import FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance @@ -18,8 +24,12 @@ from core.model_runtime.entities.message_entities import PromptMessage, PromptMe from core.model_runtime.entities.model_entities import ModelFeature from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.entities.tool_entities import (ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter, - ToolRuntimeVariablePool) +from core.tools.entities.tool_entities import ( + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolParameter, + ToolRuntimeVariablePool, +) from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool from core.tools.tool.tool import Tool from core.tools.tool_file_manager import ToolFileManager @@ -40,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner): message: Message, user_id: str, memory: Optional[TokenBufferMemory] = None, - prompt_messages: Optional[List[PromptMessage]] = None, + prompt_messages: Optional[list[PromptMessage]] = None, variables_pool: Optional[ToolRuntimeVariablePool] = None, db_variables: Optional[ToolConversationVariables] = None, model_instance: ModelInstance = None @@ -112,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner): return app_orchestration_config - def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str: + def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: """ Handle tool response """ @@ -124,13 +134,13 @@ class BaseAssistantApplicationRunner(AppRunner): result += f"result link: {response.message}. please tell user to check it." elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \ response.type == ToolInvokeMessage.MessageType.IMAGE: - result += f"image has been created and sent to user already, you should tell user to check it now." + result += "image has been created and sent to user already, you should tell user to check it now." else: result += f"tool response: {response.message}." 
return result - def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]: + def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]: """ convert tool to prompt message tool """ @@ -315,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner): return prompt_tool - def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]: + def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]: """ Extract tool response binary """ @@ -346,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner): return result - def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]: + def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]: """ Create message file @@ -394,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner): return result def create_agent_thought(self, message_id: str, message: str, - tool_name: str, tool_input: str, messages_ids: List[str] + tool_name: str, tool_input: str, messages_ids: list[str] ) -> MessageAgentThought: """ Create agent thought @@ -439,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner): thought: str, observation: str, answer: str, - messages_ids: List[str], + messages_ids: list[str], llm_usage: LLMUsage = None) -> MessageAgentThought: """ Save agent thought @@ -495,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner): db.session.commit() - def get_history_prompt_messages(self) -> List[PromptMessage]: + def get_history_prompt_messages(self) -> list[PromptMessage]: """ Get history prompt messages """ @@ -506,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner): return self.history_prompt_messages - def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]: + def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]: """ Transform tool message into agent thought """ diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index c9cb2ba61a..b8d08bb5d3 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -1,18 +1,28 @@ import json -import logging import re -from typing import Dict, Generator, List, Literal, Union +from collections.abc import Generator +from typing import Literal, Union from core.application_queue_manager import PublishFrom from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit from core.features.assistant_base_runner import BaseAssistantApplicationRunner -from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.utils.encoders import jsonable_encoder -from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError, - ToolProviderCredentialValidationError, ToolProviderNotFoundError) +from core.tools.errors import ( + 
ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) from models.model import Conversation, Message @@ -20,6 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): def run(self, conversation: Conversation, message: Message, query: str, + inputs: dict[str, str], ) -> Union[Generator, LLMResult]: """ Run Cot agent application @@ -27,7 +38,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): app_orchestration_config = self.app_orchestration_config self._repack_app_orchestration_config(app_orchestration_config) - agent_scratchpad: List[AgentScratchpadUnit] = [] + agent_scratchpad: list[AgentScratchpadUnit] = [] # check model mode if self.app_orchestration_config.model_config.mode == "completion": @@ -35,13 +46,18 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): if 'Observation' not in app_orchestration_config.model_config.stop: app_orchestration_config.model_config.stop.append('Observation') + # override inputs + inputs = inputs or {} + instruction = self.app_orchestration_config.prompt_template.simple_prompt_template + instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) + iteration_step = 1 max_iteration_steps = min(self.app_orchestration_config.agent.max_iteration, 5) + 1 prompt_messages = self.history_prompt_messages # convert tools into ModelRuntime Tool format - prompt_messages_tools: List[PromptMessageTool] = [] + prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: try: @@ -68,7 +84,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): } final_answer = '' - def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if not final_llm_usage_dict['usage']: final_llm_usage_dict['usage'] = usage else: @@ -108,7 +124,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): tools=prompt_messages_tools, agent_scratchpad=agent_scratchpad, agent_prompt_message=app_orchestration_config.agent.prompt, - instruction=app_orchestration_config.prompt_template.simple_prompt_template, + instruction=instruction, input=query ) @@ -223,7 +239,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): message_file_ids = [message_file.id for message_file, _ in message_files] except ToolProviderCredentialValidationError as e: - error_response = f"Please check your tool provider credentials" + error_response = "Please check your tool provider credentials" except ( ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError ) as e: @@ -300,6 +316,18 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): system_fingerprint='' ), PublishFrom.APPLICATION_MANAGER) + def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dict) -> str: + """ + fill in inputs from external data tools + """ + for key, value in inputs.items(): + try: + instruction = instruction.replace(f'{{{{{key}}}}}', str(value)) + except Exception as e: + continue + + return instruction + def _extract_response_scratchpad(self, content: str) -> AgentScratchpadUnit: """ extract response from llm response @@ -446,7 +474,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): next_iteration = 
agent_prompt_message.next_iteration if not isinstance(first_prompt, str) or not isinstance(next_iteration, str): - raise ValueError(f"first_prompt or next_iteration is required in CoT agent mode") + raise ValueError("first_prompt or next_iteration is required in CoT agent mode") # check instruction, tools, and tool_names slots if not first_prompt.find("{{instruction}}") >= 0: @@ -466,7 +494,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): if not next_iteration.find("{{observation}}") >= 0: raise ValueError("{{observation}} is required in next_iteration") - def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str: + def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str: """ convert agent scratchpad list to str """ @@ -479,13 +507,13 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner): return result def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], - prompt_messages: List[PromptMessage], - tools: List[PromptMessageTool], - agent_scratchpad: List[AgentScratchpadUnit], + prompt_messages: list[PromptMessage], + tools: list[PromptMessageTool], + agent_scratchpad: list[AgentScratchpadUnit], agent_prompt_message: AgentPromptEntity, instruction: str, input: str, - ) -> List[PromptMessage]: + ) -> list[PromptMessage]: """ organize chain of thought prompt messages, a standard prompt message is like: Respond to the human as helpfully and accurately as possible. diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py index 8b42244838..7ad9d7bd2a 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/features/assistant_fc_runner.py @@ -1,15 +1,27 @@ import json import logging -from typing import Any, Dict, Generator, List, Tuple, Union +from collections.abc import Generator +from typing import Any, Union from core.application_queue_manager import PublishFrom from core.features.assistant_base_runner import BaseAssistantApplicationRunner -from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage) -from core.tools.errors import (ToolInvokeError, ToolNotFoundError, ToolNotSupportedError, ToolParameterValidationError, - ToolProviderCredentialValidationError, ToolProviderNotFoundError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.tools.errors import ( + ToolInvokeError, + ToolNotFoundError, + ToolNotSupportedError, + ToolParameterValidationError, + ToolProviderCredentialValidationError, + ToolProviderNotFoundError, +) from models.model import Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) @@ -33,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): ) # convert tools into ModelRuntime Tool format - prompt_messages_tools: List[PromptMessageTool] = [] + prompt_messages_tools: list[PromptMessageTool] = [] tool_instances = {} for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []: try: @@ -59,13 +71,13 @@ class 
AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): # continue to run until there is not any tool call function_call_state = True - agent_thoughts: List[MessageAgentThought] = [] + agent_thoughts: list[MessageAgentThought] = [] llm_usage = { 'usage': None } final_answer = '' - def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage): + def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if not final_llm_usage_dict['usage']: final_llm_usage_dict['usage'] = usage else: @@ -106,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): callbacks=[], ) - tool_calls: List[Tuple[str, str, Dict[str, Any]]] = [] + tool_calls: list[tuple[str, str, dict[str, Any]]] = [] # save full response response = '' @@ -266,7 +278,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): message_file_ids.append(message_file.id) except ToolProviderCredentialValidationError as e: - error_response = f"Please check your tool provider credentials" + error_response = "Please check your tool provider credentials" except ( ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError ) as e: @@ -353,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): return True return False - def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: + def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract tool calls from llm result chunk @@ -370,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): return tool_calls - def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]: + def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]: """ Extract blocking tool calls from llm result diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval.py index f8fcea7c10..488a8ca8d0 100644 --- a/api/core/features/dataset_retrieval.py +++ b/api/core/features/dataset_retrieval.py @@ -1,4 +1,6 @@ -from typing import List, Optional, cast +from typing import Optional, cast + +from langchain.tools import BaseTool from core.agent.agent_executor import AgentConfiguration, AgentExecutor, PlanningStrategy from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler @@ -9,7 +11,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.tools.tool.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool from extensions.ext_database import db -from langchain.tools import BaseTool from models.dataset import Dataset @@ -95,7 +96,7 @@ class DatasetRetrievalFeature: return_resource: bool, invoke_from: InvokeFrom, hit_callback: DatasetIndexToolCallbackHandler) \ - -> Optional[List[BaseTool]]: + -> Optional[list[BaseTool]]: """ A dataset tool is a tool that can be used to retrieve information from a dataset :param tenant_id: tenant id diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py index 791fbf6ae3..7f23c8ed72 100644 --- a/api/core/features/external_data_fetch.py +++ b/api/core/features/external_data_fetch.py @@ -2,11 +2,12 @@ import concurrent import json 
import logging from concurrent.futures import ThreadPoolExecutor -from typing import Optional, Tuple +from typing import Optional + +from flask import Flask, current_app from core.entities.application_entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from flask import Flask, current_app logger = logging.getLogger(__name__) @@ -61,7 +62,7 @@ class ExternalDataFetchFeature: app_id: str, external_data_tool: ExternalDataVariableEntity, inputs: dict, - query: str) -> Tuple[Optional[str], Optional[str]]: + query: str) -> tuple[Optional[str], Optional[str]]: """ Query external data tool. :param flask_app: flask app diff --git a/api/core/features/moderation.py b/api/core/features/moderation.py index 9735fad0e7..a9d65f56e8 100644 --- a/api/core/features/moderation.py +++ b/api/core/features/moderation.py @@ -1,5 +1,4 @@ import logging -from typing import Tuple from core.entities.application_entities import AppOrchestrationConfigEntity from core.moderation.base import ModerationAction, ModerationException @@ -13,7 +12,7 @@ class ModerationFeature: tenant_id: str, app_orchestration_config_entity: AppOrchestrationConfigEntity, inputs: dict, - query: str) -> Tuple[bool, dict, str]: + query: str) -> tuple[bool, dict, str]: """ Process sensitive_word_avoidance. :param app_id: app id diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 626dbbca43..435074f743 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -1,11 +1,12 @@ import enum from typing import Optional +from pydantic import BaseModel + from core.file.upload_file_parser import UploadFileParser from core.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import UploadFile -from pydantic import BaseModel class FileType(enum.Enum): diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py index c92f9e6950..1b7b8b87da 100644 --- a/api/core/file/message_file_parser.py +++ b/api/core/file/message_file_parser.py @@ -1,6 +1,7 @@ -from typing import Dict, List, Optional, Union +from typing import Optional, Union import requests + from core.file.file_obj import FileBelongsTo, FileObj, FileTransferMethod, FileType from extensions.ext_database import db from models.account import Account @@ -14,8 +15,8 @@ class MessageFileParser: self.tenant_id = tenant_id self.app_id = app_id - def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig, - user: Union[Account, EndUser]) -> List[FileObj]: + def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig, + user: Union[Account, EndUser]) -> list[FileObj]: """ validate and transform files arg @@ -95,7 +96,7 @@ class MessageFileParser: # return all file objs return new_files - def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]: + def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]: """ transform message files @@ -109,8 +110,8 @@ class MessageFileParser: # return all file objs return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs] - def _to_file_objs(self, files: List[Union[Dict, MessageFile]], - file_upload_config: dict) -> Dict[FileType, List[FileObj]]: + def _to_file_objs(self, files: list[Union[dict, MessageFile]], + file_upload_config: dict) -> dict[FileType, 
list[FileObj]]: """ transform files to file objs @@ -118,7 +119,7 @@ class MessageFileParser: :param file_upload_config: :return: """ - type_file_objs: Dict[FileType, List[FileObj]] = { + type_file_objs: dict[FileType, list[FileObj]] = { # Currently only support image FileType.IMAGE: [] } diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index ca63301a59..b259a911d8 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -6,9 +6,10 @@ import os import time from typing import Optional -from extensions.ext_storage import storage from flask import current_app +from extensions.ext_storage import storage + IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) diff --git a/api/core/generator/llm_generator.py b/api/core/generator/llm_generator.py index 2a15575360..072b02dc94 100644 --- a/api/core/generator/llm_generator.py +++ b/api/core/generator/llm_generator.py @@ -1,6 +1,8 @@ import json import logging +from langchain.schema import OutputParserException + from core.model_manager import ModelManager from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import ModelType @@ -9,7 +11,6 @@ from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorO from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser from core.prompt.prompt_template import PromptTemplateParser from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT -from langchain.schema import OutputParserException class LLMGenerator: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py new file mode 100644 index 0000000000..0bfe763fac --- /dev/null +++ b/api/core/helper/ssrf_proxy.py @@ -0,0 +1,47 @@ +""" +Proxy requests to avoid SSRF +""" + +import os + +from httpx import get as _get +from httpx import head as _head +from httpx import options as _options +from httpx import patch as _patch +from httpx import post as _post +from httpx import put as _put +from requests import delete as _delete + +SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '') +SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '') + +requests_proxies = { + 'http': SSRF_PROXY_HTTP_URL, + 'https': SSRF_PROXY_HTTPS_URL +} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None + +httpx_proxies = { + 'http://': SSRF_PROXY_HTTP_URL, + 'https://': SSRF_PROXY_HTTPS_URL +} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None + +def get(url, *args, **kwargs): + return _get(url=url, *args, proxies=httpx_proxies, **kwargs) + +def post(url, *args, **kwargs): + return _post(url=url, *args, proxies=httpx_proxies, **kwargs) + +def put(url, *args, **kwargs): + return _put(url=url, *args, proxies=httpx_proxies, **kwargs) + +def patch(url, *args, **kwargs): + return _patch(url=url, *args, proxies=httpx_proxies, **kwargs) + +def delete(url, *args, **kwargs): + return _delete(url=url, *args, proxies=requests_proxies, **kwargs) + +def head(url, *args, **kwargs): + return _head(url=url, *args, proxies=httpx_proxies, **kwargs) + +def options(url, *args, **kwargs): + return _options(url=url, *args, proxies=httpx_proxies, **kwargs) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index b2917682dc..58b551f295 100644 --- a/api/core/hosting_configuration.py +++ 
b/api/core/hosting_configuration.py @@ -1,10 +1,11 @@ from typing import Optional +from flask import Config, Flask +from pydantic import BaseModel + from core.entities.provider_entities import QuotaUnit, RestrictModel from core.model_runtime.entities.model_entities import ModelType -from flask import Config, Flask from models.provider import ProviderQuotaType -from pydantic import BaseModel class HostingQuota(BaseModel): @@ -123,7 +124,6 @@ class HostingConfiguration: restrict_models=[ RestrictModel(model="gpt-4", model_type=ModelType.LLM), RestrictModel(model="gpt-4-turbo-preview", model_type=ModelType.LLM), - RestrictModel(model="gpt-4-32k", model_type=ModelType.LLM), RestrictModel(model="gpt-4-1106-preview", model_type=ModelType.LLM), RestrictModel(model="gpt-4-0125-preview", model_type=ModelType.LLM), RestrictModel(model="gpt-3.5-turbo", model_type=ModelType.LLM), diff --git a/api/core/index/base.py b/api/core/index/base.py index 33178ff83b..f8eb1a134a 100644 --- a/api/core/index/base.py +++ b/api/core/index/base.py @@ -1,9 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, List +from typing import Any from langchain.schema import BaseRetriever, Document + from models.dataset import Dataset @@ -52,7 +53,7 @@ class BaseIndex(ABC): def search( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: raise NotImplementedError def delete(self) -> None: diff --git a/api/core/index/index.py b/api/core/index/index.py index 56ce3c99c6..42971c895e 100644 --- a/api/core/index/index.py +++ b/api/core/index/index.py @@ -1,10 +1,11 @@ +from flask import current_app +from langchain.embeddings import OpenAIEmbeddings + from core.embedding.cached_embedding import CacheEmbedding from core.index.keyword_table_index.keyword_table_index import KeywordTableConfig, KeywordTableIndex from core.index.vector_index.vector_index import VectorIndex from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType -from flask import current_app -from langchain.embeddings import OpenAIEmbeddings from models.dataset import Dataset diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py index fc07402206..df93a1903a 100644 --- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py +++ b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py @@ -1,17 +1,17 @@ import re -from typing import Set import jieba -from core.index.keyword_table_index.stopwords import STOPWORDS from jieba.analyse import default_tfidf +from core.index.keyword_table_index.stopwords import STOPWORDS + class JiebaKeywordTableHandler: def __init__(self): default_tfidf.stop_words = STOPWORDS - def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]: + def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]: """Extract keywords with JIEBA tfidf.""" keywords = jieba.analyse.extract_tags( sentence=text, @@ -20,7 +20,7 @@ class JiebaKeywordTableHandler: return set(self._expand_tokens_with_subtokens(keywords)) - def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]: + def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: """Get subtokens from a list of tokens., filtering for stopwords.""" results = set() for token in tokens: diff --git a/api/core/index/keyword_table_index/keyword_table_index.py 
b/api/core/index/keyword_table_index/keyword_table_index.py index 06eef1ebf2..8bf0b13344 100644 --- a/api/core/index/keyword_table_index/keyword_table_index.py +++ b/api/core/index/keyword_table_index/keyword_table_index.py @@ -1,13 +1,14 @@ import json from collections import defaultdict -from typing import Any, Dict, List, Optional +from typing import Any, Optional + +from langchain.schema import BaseRetriever, Document +from pydantic import BaseModel, Extra, Field from core.index.base import BaseIndex from core.index.keyword_table_index.jieba_keyword_table_handler import JiebaKeywordTableHandler from extensions.ext_database import db -from langchain.schema import BaseRetriever, Document from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment -from pydantic import BaseModel, Extra, Field class KeywordTableConfig(BaseModel): @@ -115,7 +116,7 @@ class KeywordTableIndex(BaseIndex): def search( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: keyword_table = self._get_dataset_keyword_table() search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {} @@ -220,7 +221,7 @@ class KeywordTableIndex(BaseIndex): keywords = keyword_table_handler.extract_keywords(query) # go through text chunks in order of most matching keywords - chunk_indices_count: Dict[str, int] = defaultdict(int) + chunk_indices_count: dict[str, int] = defaultdict(int) keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())] for keyword in keywords: for node_id in keyword_table[keyword]: @@ -234,7 +235,7 @@ class KeywordTableIndex(BaseIndex): return sorted_chunk_indices[: k] - def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]): + def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]): document_segment = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id @@ -243,7 +244,7 @@ class KeywordTableIndex(BaseIndex): document_segment.keywords = keywords db.session.commit() - def create_segment_keywords(self, node_id: str, keywords: List[str]): + def create_segment_keywords(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() self._update_segment_keywords(self.dataset.id, node_id, keywords) keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) @@ -265,7 +266,7 @@ class KeywordTableIndex(BaseIndex): keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords)) self._save_dataset_keyword_table(keyword_table) - def update_segment_keywords_index(self, node_id: str, keywords: List[str]): + def update_segment_keywords_index(self, node_id: str, keywords: list[str]): keyword_table = self._get_dataset_keyword_table() keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords) self._save_dataset_keyword_table(keyword_table) @@ -281,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel): extra = Extra.forbid arbitrary_types_allowed = True - def get_relevant_documents(self, query: str) -> List[Document]: + def get_relevant_documents(self, query: str) -> list[Document]: """Get documents relevant for a query. 
Args: @@ -292,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel): """ return self.index.search(query, **self.search_kwargs) - async def aget_relevant_documents(self, query: str) -> List[Document]: + async def aget_relevant_documents(self, query: str) -> list[Document]: raise NotImplementedError("KeywordTableRetriever does not support async") diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py index ccc1833821..36aa1917a6 100644 --- a/api/core/index/vector_index/base.py +++ b/api/core/index/vector_index/base.py @@ -1,16 +1,16 @@ import json import logging from abc import abstractmethod -from typing import Any, List, cast +from typing import Any, cast -from core.index.base import BaseIndex -from extensions.ext_database import db from langchain.embeddings.base import Embeddings from langchain.schema import BaseRetriever, Document from langchain.vectorstores import VectorStore -from models.dataset import Dataset, DatasetCollectionBinding + +from core.index.base import BaseIndex +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment class BaseVectorIndex(BaseIndex): @@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex): def search_by_full_text_index( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: raise NotImplementedError def search( self, query: str, **kwargs: Any - ) -> List[Document]: + ) -> list[Document]: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py index 67ba5a7b32..a18cf35a27 100644 --- a/api/core/index/vector_index/milvus_vector_index.py +++ b/api/core/index/vector_index/milvus_vector_index.py @@ -1,13 +1,14 @@ -from typing import Any, List, cast +from typing import Any, cast + +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore +from pydantic import BaseModel, root_validator from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.milvus_vector_store import MilvusVectorStore -from langchain.embeddings.base import Embeddings -from langchain.schema import Document -from langchain.vectorstores import VectorStore from models.dataset import Dataset -from pydantic import BaseModel, root_validator class MilvusConfig(BaseModel): @@ -159,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex): ], )) - def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: + def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: # milvus/zilliz doesn't support bm25 search return [] diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py index f755fe4101..046260d2f8 100644 --- a/api/core/index/vector_index/qdrant_vector_index.py +++ b/api/core/index/vector_index/qdrant_vector_index.py @@ -1,17 +1,18 @@ import os -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast import qdrant_client +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore +from pydantic import BaseModel +from qdrant_client.http.models import HnswConfigDiff + from 
core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.qdrant_vector_store import QdrantVectorStore from extensions.ext_database import db -from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import VectorStore from models.dataset import Dataset, DatasetCollectionBinding -from pydantic import BaseModel -from qdrant_client.http.models import HnswConfigDiff class QdrantConfig(BaseModel): @@ -209,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex): return False - def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: + def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) diff --git a/api/core/index/vector_index/vector_index.py b/api/core/index/vector_index/vector_index.py index 0a69c4f734..ed6e2699d6 100644 --- a/api/core/index/vector_index/vector_index.py +++ b/api/core/index/vector_index/vector_index.py @@ -1,9 +1,10 @@ import json +from flask import current_app +from langchain.embeddings.base import Embeddings + from core.index.vector_index.base import BaseVectorIndex from extensions.ext_database import db -from flask import current_app -from langchain.embeddings.base import Embeddings from models.dataset import Dataset, Document @@ -25,7 +26,7 @@ class VectorIndex: vector_type = self._dataset.index_struct_dict['type'] if not vector_type: - raise ValueError(f"Vector store must be specified.") + raise ValueError("Vector store must be specified.") if vector_type == "weaviate": from core.index.vector_index.weaviate_vector_index import WeaviateConfig, WeaviateVectorIndex diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py index b4add6c11a..72a74a039f 100644 --- a/api/core/index/vector_index/weaviate_vector_index.py +++ b/api/core/index/vector_index/weaviate_vector_index.py @@ -1,15 +1,16 @@ -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast import requests import weaviate +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +from langchain.vectorstores import VectorStore +from pydantic import BaseModel, root_validator + from core.index.base import BaseIndex from core.index.vector_index.base import BaseVectorIndex from core.vector_store.weaviate_vector_store import WeaviateVectorStore -from langchain.embeddings.base import Embeddings -from langchain.schema import BaseRetriever, Document -from langchain.vectorstores import VectorStore from models.dataset import Dataset -from pydantic import BaseModel, root_validator class WeaviateConfig(BaseModel): @@ -171,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex): return False - def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]: + def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]: vector_store = self._get_vector_store() vector_store = cast(self._get_vector_store_class(), vector_store) return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 2f1cf282f8..a14001d04e 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,7 +5,13 @@ import re import threading import time import uuid -from typing import AbstractSet, Any, 
Collection, List, Literal, Optional, Type, Union, cast +from typing import Optional, cast + +from flask import Flask, current_app +from flask_login import current_user +from langchain.schema import Document +from langchain.text_splitter import TextSplitter +from sqlalchemy.orm.exc import ObjectDeletedError from core.data_loader.file_extractor import FileExtractor from core.data_loader.loader.notion import NotionLoader @@ -17,22 +23,15 @@ from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import ModelType, PriceType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from core.spiltter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter +from core.splitter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage -from flask import Flask, current_app -from flask_login import current_user -from langchain.schema import Document -from langchain.text_splitter import TS, TextSplitter, TokenTextSplitter from libs import helper -from models.dataset import Dataset, DatasetProcessRule +from models.dataset import Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument -from models.dataset import DocumentSegment from models.model import UploadFile from models.source import DataSourceBinding -from sqlalchemy.orm.exc import ObjectDeletedError class IndexingRunner: @@ -41,7 +40,7 @@ class IndexingRunner: self.storage = storage self.model_manager = ModelManager() - def run(self, dataset_documents: List[DatasetDocument]): + def run(self, dataset_documents: list[DatasetDocument]): """Run the indexing process.""" for dataset_document in dataset_documents: try: @@ -239,7 +238,7 @@ class IndexingRunner: dataset_document.stopped_at = datetime.datetime.utcnow() db.session.commit() - def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict, + def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict, doc_form: str = None, doc_language: str = 'English', dataset_id: str = None, indexing_technique: str = 'economy') -> dict: """ @@ -495,7 +494,7 @@ class IndexingRunner: "preview": preview_texts } - def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]: + def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]: # load file if dataset_document.data_source_type not in ["upload_file", "notion_import"]: return [] @@ -527,7 +526,7 @@ class IndexingRunner: ) # replace doc id to document model id - text_docs = cast(List[Document], text_docs) + text_docs = cast(list[Document], text_docs) for text_doc in text_docs: # remove invalid symbol text_doc.page_content = self.filter_string(text_doc.page_content) @@ -541,7 +540,7 @@ class IndexingRunner: text = re.sub(r'\|>', '>', text) text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text) # Unicode U+FFFE - text = re.sub(u'\uFFFE', '', text) + text = re.sub('\uFFFE', '', text) return text def _get_splitter(self, processing_rule: 
DatasetProcessRule, @@ -578,9 +577,9 @@ class IndexingRunner: return character_splitter - def _step_split(self, text_docs: List[Document], splitter: TextSplitter, + def _step_split(self, text_docs: list[Document], splitter: TextSplitter, dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \ - -> List[Document]: + -> list[Document]: """ Split the text documents into documents and save them to the document segment. """ @@ -625,9 +624,9 @@ class IndexingRunner: return documents - def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, + def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> List[Document]: + document_form: str, document_language: str) -> list[Document]: """ Split the text documents into nodes. """ @@ -700,8 +699,8 @@ class IndexingRunner: all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> List[Document]: + def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, + processing_rule: DatasetProcessRule) -> list[Document]: """ Split the text documents into nodes. """ @@ -771,7 +770,7 @@ class IndexingRunner: for q, a in matches if q and a ] - def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: + def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: """ Build the index for the document. """ @@ -878,7 +877,7 @@ class IndexingRunner: DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset): + def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 663daa0856..f1f8ab3a3b 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,7 +1,12 @@ from core.file.message_file_parser import MessageFileParser from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - TextPromptMessageContent, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers import model_provider_factory from extensions.ext_database import db diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 8a622e4f5b..aa16cf866f 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,4 +1,5 @@ -from typing import IO, Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderModelBundle from core.errors.error import ProviderTokenNotInitError @@ -47,7 +48,7 @@ class ModelInstance: return credentials def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: 
Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ -> Union[LLMResult, Generator]: """ @@ -63,7 +64,7 @@ class ModelInstance: :return: full response or stream response chunk generator result """ if not isinstance(self.model_type_instance, LargeLanguageModel): - raise Exception(f"Model type instance is not LargeLanguageModel") + raise Exception("Model type instance is not LargeLanguageModel") self.model_type_instance = cast(LargeLanguageModel, self.model_type_instance) return self.model_type_instance.invoke( @@ -88,7 +89,7 @@ class ModelInstance: :return: embeddings result """ if not isinstance(self.model_type_instance, TextEmbeddingModel): - raise Exception(f"Model type instance is not TextEmbeddingModel") + raise Exception("Model type instance is not TextEmbeddingModel") self.model_type_instance = cast(TextEmbeddingModel, self.model_type_instance) return self.model_type_instance.invoke( @@ -98,7 +99,8 @@ class ModelInstance: user=user ) - def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + def invoke_rerank(self, query: str, docs: list[str], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, user: Optional[str] = None) \ -> RerankResult: """ @@ -112,7 +114,7 @@ class ModelInstance: :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): - raise Exception(f"Model type instance is not RerankModel") + raise Exception("Model type instance is not RerankModel") self.model_type_instance = cast(RerankModel, self.model_type_instance) return self.model_type_instance.invoke( @@ -135,7 +137,7 @@ class ModelInstance: :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): - raise Exception(f"Model type instance is not ModerationModel") + raise Exception("Model type instance is not ModerationModel") self.model_type_instance = cast(ModerationModel, self.model_type_instance) return self.model_type_instance.invoke( @@ -155,7 +157,7 @@ class ModelInstance: :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): - raise Exception(f"Model type instance is not Speech2TextModel") + raise Exception("Model type instance is not Speech2TextModel") self.model_type_instance = cast(Speech2TextModel, self.model_type_instance) return self.model_type_instance.invoke( @@ -165,18 +167,20 @@ class ModelInstance: user=user ) - def invoke_tts(self, content_text: str, streaming: bool, user: Optional[str] = None) \ + def invoke_tts(self, content_text: str, tenant_id: str, voice: str, streaming: bool, user: Optional[str] = None) \ -> str: """ - Invoke large language model + Invoke TTS model :param content_text: text content to be translated + :param tenant_id: user tenant id :param user: unique user id + :param voice: voice timbre to use :param streaming: output is streaming :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): - raise Exception(f"Model type instance is not TTSModel") + raise Exception("Model type instance is not TTSModel") self.model_type_instance = cast(TTSModel, self.model_type_instance) return self.model_type_instance.invoke( @@ -184,9 +188,28 @@ class ModelInstance: credentials=self.credentials, content_text=content_text, user=user, +
tenant_id=tenant_id, + voice=voice, streaming=streaming ) + def get_tts_voices(self, language: str) -> list: + """ + Get TTS model voices + + :param language: tts language + :return: tts model voices + """ + if not isinstance(self.model_type_instance, TTSModel): + raise Exception("Model type instance is not TTSModel") + + self.model_type_instance = cast(TTSModel, self.model_type_instance) + return self.model_type_instance.get_tts_model_voices( + model=self.model, + credentials=self.credentials, + language=language + ) + class ModelManager: def __init__(self) -> None: diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 58150ef4da..51af9786fd 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Optional +from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -23,7 +23,7 @@ class Callback(ABC): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -42,7 +42,7 @@ class Callback(ABC): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -62,7 +62,7 @@ class Callback(ABC): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -82,7 +82,7 @@ class Callback(ABC): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index e6268a7b09..0406853b88 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,7 @@ import json import logging import sys -from typing import List, Optional +from typing import Optional from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) class LoggingCallback(Callback): def
on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -30,7 +30,7 @@ class LoggingCallback(Callback): """ self.print_text("\n[on_llm_before_invoke]\n", color='blue') self.print_text(f"Model: {model}\n", color='blue') - self.print_text(f"Parameters:\n", color='blue') + self.print_text("Parameters:\n", color='blue') for key, value in model_parameters.items(): self.print_text(f"\t{key}: {value}\n", color='blue') @@ -38,7 +38,7 @@ class LoggingCallback(Callback): self.print_text(f"\tstop: {stop}\n", color='blue') if tools: - self.print_text(f"\tTools:\n", color='blue') + self.print_text("\tTools:\n", color='blue') for tool in tools: self.print_text(f"\t\t{tool.name}\n", color='blue') @@ -47,7 +47,7 @@ class LoggingCallback(Callback): if user: self.print_text(f"User: {user}\n", color='blue') - self.print_text(f"Prompt messages:\n", color='blue') + self.print_text("Prompt messages:\n", color='blue') for prompt_message in prompt_messages: if prompt_message.name: self.print_text(f"\tname: {prompt_message.name}\n", color='blue') @@ -60,7 +60,7 @@ class LoggingCallback(Callback): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -81,7 +81,7 @@ class LoggingCallback(Callback): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -101,7 +101,7 @@ class LoggingCallback(Callback): self.print_text(f"Content: {result.message.content}\n", color='yellow') if result.message.tool_calls: - self.print_text(f"Tool calls:\n", color='yellow') + self.print_text("Tool calls:\n", color='yellow') for tool_call in result.message.tool_calls: self.print_text(f"\t{tool_call.id}\n", color='yellow') self.print_text(f"\t{tool_call.function.name}\n", color='yellow') @@ -113,7 +113,7 @@ class LoggingCallback(Callback): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/model_runtime/docs/en_US/provider_scale_out.md b/api/core/model_runtime/docs/en_US/provider_scale_out.md index d93a5426b5..ba356c5cab 100644 --- a/api/core/model_runtime/docs/en_US/provider_scale_out.md +++ b/api/core/model_runtime/docs/en_US/provider_scale_out.md @@ -161,7 +161,7 @@ In `llm.py`, create an Anthropic LLM class, which we 
name `AnthropicLargeLanguageModel`: ```python def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ diff --git a/api/core/model_runtime/docs/en_US/schema.md b/api/core/model_runtime/docs/en_US/schema.md index 9606579e1c..61cd2c32d4 100644 --- a/api/core/model_runtime/docs/en_US/schema.md +++ b/api/core/model_runtime/docs/en_US/schema.md @@ -48,6 +48,10 @@ - `file_upload_limit` (int) Maximum file upload limit, in MB (available for model type `speech2text`) - `supported_file_extensions` (string) Supported file extension formats, e.g., mp3, mp4 (available for model type `speech2text`) - `default_voice` (string) default voice, e.g.:alloy,echo,fable,onyx,nova,shimmer(available for model type `tts`) + - `voices` (list) List of available voices. (available for model type `tts`) + - `mode` (string) voice model identifier. (available for model type `tts`) + - `name` (string) voice model display name. (available for model type `tts`) + - `language` (string) languages supported by the voice model. (available for model type `tts`) - `word_limit` (int) Single conversion word limit, paragraphwise by default(available for model type `tts`) - `audio_type` (string) Support audio file extension format, e.g.:mp3,wav(available for model type `tts`) - `max_workers` (int) Number of concurrent workers supporting text and audio conversion(available for model type`tts`) diff --git a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md index ccf78d0cdb..7b3a8edba3 100644 --- a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md @@ -127,7 +127,7 @@ provider_credential_schema: ```python def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ diff --git a/api/core/model_runtime/docs/zh_Hans/interfaces.md b/api/core/model_runtime/docs/zh_Hans/interfaces.md index 5bd505a0ee..743e575ded 100644 --- a/api/core/model_runtime/docs/zh_Hans/interfaces.md +++ b/api/core/model_runtime/docs/zh_Hans/interfaces.md @@ -128,7 +128,7 @@ class XinferenceProvider(Provider): ```python def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ diff --git a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md index c90fb577ca..56f379a92f 100644 --- a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md +++ b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md @@ -77,7 +77,7 @@ pricing: # 价格信息 ```python def _invoke(self, model: str, credentials: dict, prompt_messages: 
list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ diff --git a/api/core/model_runtime/docs/zh_Hans/schema.md b/api/core/model_runtime/docs/zh_Hans/schema.md index 1eab541d24..55202a1a80 100644 --- a/api/core/model_runtime/docs/zh_Hans/schema.md +++ b/api/core/model_runtime/docs/zh_Hans/schema.md @@ -48,7 +48,11 @@ - `max_chunks` (int) 最大分块数量 (模型类型 `text-embedding ` `moderation` 可用) - `file_upload_limit` (int) 文件最大上传限制,单位:MB。(模型类型 `speech2text` 可用) - `supported_file_extensions` (string) 支持文件扩展格式,如:mp3,mp4(模型类型 `speech2text` 可用) - - `default_voice` (string) 缺省音色,可选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用) + - `default_voice` (string) 缺省音色,必选:alloy,echo,fable,onyx,nova,shimmer(模型类型 `tts` 可用) + - `voices` (list) 可选音色列表。(模型类型 `tts` 可用) + - `mode` (string) 音色模型。(模型类型 `tts` 可用) + - `name` (string) 音色模型显示名称。(模型类型 `tts` 可用) + - `language` (string) 音色模型支持语言。(模型类型 `tts` 可用) - `word_limit` (int) 单次转换字数限制,默认按段落分段(模型类型 `tts` 可用) - `audio_type` (string) 支持音频文件扩展格式,如:mp3,wav(模型类型 `tts` 可用) - `max_workers` (int) 支持文字音频转换并发任务数(模型类型 `tts` 可用) diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index b39427dccd..856f4ce7d1 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,8 +1,7 @@ -from typing import Dict from core.model_runtime.entities.model_entities import DefaultParameterName -PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = { +PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { 'label': { 'en_US': 'Temperature', diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 76d4ef310e..b5bd9e267a 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -2,9 +2,10 @@ from decimal import Decimal from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo -from pydantic import BaseModel class LLMMode(Enum): diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 2041cb3a97..e35be27f86 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -2,9 +2,10 @@ from decimal import Decimal from enum import Enum from typing import Any, Optional -from core.model_runtime.entities.common_entities import I18nObject from pydantic import BaseModel +from core.model_runtime.entities.common_entities import I18nObject + class ModelType(Enum): """ @@ -126,6 +127,7 @@ class ModelPropertyKey(Enum): SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions" MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk" DEFAULT_VOICE = "default_voice" + VOICES = "voices" WORD_LIMIT = "word_limit" AUDOI_TYPE = "audio_type" MAX_WORKERS = "max_workers" diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py index bd55d60795..acc453bb84 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ 
b/api/core/model_runtime/entities/provider_entities.py @@ -1,9 +1,10 @@ from enum import Enum from typing import Optional +from pydantic import BaseModel + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, ModelType, ProviderModel -from pydantic import BaseModel class ConfigurateMethod(Enum): diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py index 499c76eb7d..7be3def379 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/core/model_runtime/entities/text_embedding_entities.py @@ -1,8 +1,9 @@ from decimal import Decimal -from core.model_runtime.entities.model_entities import ModelUsage from pydantic import BaseModel +from core.model_runtime.entities.model_entities import ModelUsage + class EmbeddingUsage(ModelUsage): """ diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 11f9a7a6fb..a9f7a539e2 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -4,10 +4,18 @@ from abc import ABC, abstractmethod from typing import Optional import yaml + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, ModelType, - PriceConfig, PriceInfo, PriceType) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelType, + PriceConfig, + PriceInfo, + PriceType, +) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer @@ -145,7 +153,7 @@ class AIModel(ABC): # read _position.yaml file position_map = {} if os.path.exists(position_file_path): - with open(position_file_path, 'r', encoding='utf-8') as f: + with open(position_file_path, encoding='utf-8') as f: positions = yaml.safe_load(f) # convert list to dict with key as model provider name, value as index position_map = {position: index for index, position in enumerate(positions)} @@ -153,7 +161,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - with open(model_schema_yaml_path, 'r', encoding='utf-8') as f: + with open(model_schema_yaml_path, encoding='utf-8') as f: yaml_data = yaml.safe_load(f) new_parameter_rules = [] diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 75ea7bacef..1f7edd245f 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -3,14 +3,20 @@ import os import re import time from abc import abstractmethod -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from 
core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import (ModelPropertyKey, ModelType, ParameterRule, ParameterType, - PriceType) +from core.model_runtime.entities.model_entities import ( + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceType, +) from core.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) @@ -24,7 +30,7 @@ class LargeLanguageModel(AIModel): def invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ -> Union[LLMResult, Generator]: """ @@ -117,7 +123,7 @@ class LargeLanguageModel(AIModel): def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: """ Invoke result generator @@ -181,7 +187,7 @@ class LargeLanguageModel(AIModel): @abstractmethod def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -213,7 +219,7 @@ class LargeLanguageModel(AIModel): """ raise NotImplementedError - def enforce_stop_tokens(self, text: str, stop: List[str]) -> str: + def enforce_stop_tokens(self, text: str, stop: list[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text, maxsplit=1)[0] @@ -324,7 +330,7 @@ class LargeLanguageModel(AIModel): def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger before invoke callbacks @@ -362,7 +368,7 @@ class LargeLanguageModel(AIModel): def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger new chunk callbacks @@ -401,7 +407,7 @@ class LargeLanguageModel(AIModel): def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: 
list[Callback] = None) -> None: """ Trigger after invoke callbacks @@ -441,7 +447,7 @@ class LargeLanguageModel(AIModel): def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger invoke error callbacks @@ -522,7 +528,7 @@ class LargeLanguageModel(AIModel): raise ValueError( f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") elif parameter_rule.type == ParameterType.FLOAT: - if not isinstance(parameter_value, (float, int)): + if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") # validate parameter value precision diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index a856d42588..97ce07d35f 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,9 +1,9 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, Optional import yaml + from core.model_runtime.entities.model_entities import AIModelEntity, ModelType from core.model_runtime.entities.provider_entities import ProviderEntity from core.model_runtime.model_providers.__base.ai_model import AIModel @@ -11,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel class ModelProvider(ABC): provider_schema: ProviderEntity = None - model_instance_map: Dict[str, AIModel] = {} + model_instance_map: dict[str, AIModel] = {} @abstractmethod def validate_provider_credentials(self, credentials: dict) -> None: @@ -46,7 +46,7 @@ class ModelProvider(ABC): yaml_path = os.path.join(current_path, f'{provider_name}.yaml') yaml_data = {} if os.path.exists(yaml_path): - with open(yaml_path, 'r', encoding='utf-8') as f: + with open(yaml_path, encoding='utf-8') as f: yaml_data = yaml.safe_load(f) try: diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index ff20cf7b9f..722d80c91e 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -15,29 +15,37 @@ class TTSModel(AIModel): """ model_type: ModelType = ModelType.TTS - def invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None): + def invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, + user: Optional[str] = None): """ Invoke large language model :param model: model name + :param tenant_id: user tenant id :param credentials: model credentials + :param voice: model timbre :param content_text: text content to be translated :param streaming: output is streaming :param user: unique user id :return: translated audio file """ try: - return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, content_text=content_text) + self._is_ffmpeg_installed() + return self._invoke(model=model, credentials=credentials, user=user, streaming=streaming, + content_text=content_text, voice=voice, tenant_id=tenant_id) except Exception as e: raise 
self._transform_invoke_error(e) @abstractmethod - def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None): + def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, + user: Optional[str] = None): """ Invoke large language model :param model: model name + :param tenant_id: user tenant id :param credentials: model credentials + :param voice: model timbre :param content_text: text content to be translated :param streaming: output is streaming :param user: unique user id @@ -45,7 +53,25 @@ class TTSModel(AIModel): """ raise NotImplementedError - def _get_model_voice(self, model: str, credentials: dict) -> any: + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + """ + Get voices for given tts model + + :param language: tts language + :param model: model name + :param credentials: model credentials + :return: voices list + """ + model_schema = self.get_model_schema(model, credentials) + + if model_schema and ModelPropertyKey.VOICES in model_schema.model_properties: + voices = model_schema.model_properties[ModelPropertyKey.VOICES] + if language: + return [{'name': d['name'], 'value': d['mode']} for d in voices if language in d.get('language', [])] + else: + return [{'name': d['name'], 'value': d['mode']} for d in voices] + + def _get_model_default_voice(self, model: str, credentials: dict) -> any: """ Get voice for given tts model diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/core/model_runtime/model_providers/_position.yaml index 9f7fe4c4f4..b2c6518395 100644 --- a/api/core/model_runtime/model_providers/_position.yaml +++ b/api/core/model_runtime/model_providers/_position.yaml @@ -14,6 +14,7 @@ - minimax - tongyi - wenxin +- moonshot - jina - chatglm - xinference diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 987f2fabf1..c743708896 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,23 +1,36 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import anthropic from anthropic import Anthropic, Stream from anthropic.types import Completion, completion_create_params +from httpx import Timeout + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from httpx import Timeout class 
AnthropicLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -78,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -243,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git a/api/core/model_runtime/model_providers/azure_openai/_common.py b/api/core/model_runtime/model_providers/azure_openai/_common.py index 627b487357..b65138252b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_common.py +++ b/api/core/model_runtime/model_providers/azure_openai/_common.py @@ -1,9 +1,16 @@ import openai -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION from httpx import Timeout +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.model_providers.azure_openai._constant import AZURE_OPENAI_API_VERSION + class _CommonAzureOpenAI: @staticmethod diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index 8104df52dd..90dd2e7a6b 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -1,9 +1,18 @@ +from pydantic import BaseModel + from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, I18nObject, - ModelFeature, ModelPropertyKey, ModelType, ParameterRule, - PriceConfig) -from pydantic import BaseModel +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + PriceConfig, +) AZURE_OPENAI_API_VERSION = '2023-12-01-preview' diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 326043aa39..4b89adaa49 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,24 +1,33 @@ import copy import logging -from typing import Generator, List, Optional, Union, 
cast +from collections.abc import Generator +from typing import Optional, Union, cast import tiktoken -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - ToolPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI -from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel from openai import AzureOpenAI, Stream from openai.types import Completion from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI +from core.model_runtime.model_providers.azure_openai._constant import LLM_BASE_MODELS, AzureBaseModel + logger = logging.getLogger(__name__) @@ -26,7 +35,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: @@ -113,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return ai_model_entity.entity if ai_model_entity else None def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) @@ -231,7 +240,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _chat_generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, 
user: Optional[str] = None) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) @@ -529,7 +538,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage], + def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index e472151cb5..e073bef014 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -1,17 +1,18 @@ import base64 import copy import time -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import tiktoken +from openai import AzureOpenAI + from core.model_runtime.entities.model_entities import AIModelEntity, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI from core.model_runtime.model_providers.azure_openai._constant import EMBEDDING_BASE_MODELS, AzureBaseModel -from openai import AzureOpenAI class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): @@ -148,7 +149,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): @staticmethod def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> Tuple[list[list[float]], int]: + extra_model_kwargs: dict) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 4562bb2be7..7549b2fb60 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -1,7 +1,7 @@ import re -class BaichuanTokenizer(object): +class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: return len(re.findall(r'[\u4e00-\u9fa5]', text)) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 48ed86f66b..639f6a21ce 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,18 +1,20 @@ +from collections.abc import Generator from enum import Enum from hashlib import md5 from json import dumps, loads -from os.path import join -from time import time -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, Union -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - 
InvalidAuthenticationError, - RateLimitReachedError) from requests import post +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + class BaichuanMessage: class Role(Enum): @@ -23,10 +25,10 @@ class BaichuanMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -36,7 +38,7 @@ class BaichuanMessage: self.content = content self.role = role -class BaichuanModel(object): +class BaichuanModel: api_key: str secret_key: str @@ -105,9 +107,9 @@ class BaichuanModel(object): message.stop_reason = stop_reason yield message - def _build_parameters(self, model: str, stream: bool, messages: List[BaichuanMessage], - parameters: Dict[str, Any]) \ - -> Dict[str, Any]: + def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage], + parameters: dict[str, Any]) \ + -> dict[str, Any]: if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': prompt_messages = [] for message in messages: @@ -126,8 +128,10 @@ class BaichuanModel(object): 'role': message.role, }) # [baichuan] frequency_penalty must be between 1 and 2 - if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2: - parameters['frequency_penalty'] = 1 + if 'frequency_penalty' in parameters: + if parameters['frequency_penalty'] < 1 or parameters['frequency_penalty'] > 2: + parameters['frequency_penalty'] = 1 + # turbo api accepts flat parameters return { 'model': self._model_mapping(model), @@ -138,7 +142,7 @@ class BaichuanModel(object): else: raise BadRequestError(f"Unknown model: {model}") - def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]: + def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]: if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': # there is no secret key for turbo api return { @@ -152,8 +156,8 @@ class BaichuanModel(object): def _calculate_md5(self, input_string): return md5(input_string.encode('utf-8')).hexdigest() - def generate(self, model: str, stream: bool, messages: List[BaichuanMessage], - parameters: Dict[str, Any], timeout: int) \ + def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], + parameters: dict[str, Any], timeout: int) \ -> Union[Generator, BaichuanMessage]: if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b': diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py index c8bb1feb52..4278120093 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py @@ -1,26 +1,40 @@ -from typing import Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, 
InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo import BaichuanMessage, BaichuanModel -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class BaichuanLarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -30,7 +44,7 @@ class BaichuanLarguageModel(LargeLanguageModel): tools: list[PromptMessageTool] | None = None) -> int: return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): return BaichuanTokenizer._get_num_tokens(text) @@ -89,15 +103,15 @@ class BaichuanLarguageModel(LargeLanguageModel): ], parameters={ 'max_tokens': 1, }, timeout=60) - except (InvalidAPIKeyError, InvalidAuthenticationError) as e: + except Exception as e: raise CredentialsValidateFailedError(f"Invalid API key: {e}") def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: if tools is not None and len(tools) > 0: - raise InvokeBadRequestError(f"Baichuan model doesn't support tools") + raise InvokeBadRequestError("Baichuan model doesn't support tools") instance = BaichuanModel( api_key=credentials['api_key'], diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py index 20aafea1eb..da4ba55881 100644 --- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py +++ 
b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py @@ -1,21 +1,30 @@ import time -from json import dumps, loads -from typing import Optional, Tuple +from json import dumps +from typing import Optional + +from requests import post from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.baichuan.llm.baichuan_tokenizer import BaichuanTokenizer -from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import (BadRequestError, - InsufficientAccountBalance, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) -from requests import post +from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class BaichuanTextEmbeddingModel(TextEmbeddingModel): @@ -75,7 +84,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel): return result def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \ - -> Tuple[list[list[float]], int]: + -> tuple[list[list[float]], int]: """ Embed given texts diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 6a9c695350..c6aaa24ade 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -1,16 +1,34 @@ import json import logging -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import boto3 from botocore.config import Config -from botocore.exceptions import (ClientError, EndpointConnectionError, NoRegionError, ServiceNotInRegionError, - UnknownServiceError) +from botocore.exceptions import ( + ClientError, + EndpointConnectionError, + NoRegionError, + ServiceNotInRegionError, + UnknownServiceError, +) + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from 
core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -20,7 +38,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -142,7 +160,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage], model_prefix: str) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str: """ Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models @@ -164,7 +182,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() - def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, stream: bool = True): + def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True): """ Create payload for bedrock api call depending on model provider """ @@ -214,7 +232,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py index 6f78f7aa88..12dc75aece 100644 --- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py +++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py @@ -1,30 +1,52 @@ import logging -from json import dumps +from collections.abc import Generator from os.path import join -from typing import Generator, List, Optional, cast +from typing import Optional, cast + +from httpx import Timeout +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + Stream, + UnprocessableEntityError, +) +from openai.types.chat import ChatCompletion, ChatCompletionChunk +from openai.types.chat.chat_completion_message import FunctionCall from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageFunction, - PromptMessageTool, SystemPromptMessage, ToolPromptMessage, - UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from 
core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils import helper -from httpx import Timeout -from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, - NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) -from openai.types.chat import ChatCompletion, ChatCompletionChunk -from openai.types.chat.chat_completion_message import FunctionCall -from requests import post logger = logging.getLogger(__name__) class ChatGLMLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -117,7 +139,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -373,7 +395,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, messages: List[PromptMessage], + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer. 
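The TTS-related hunks above collectively change the public surface of `ModelInstance`: `invoke_tts` now requires `tenant_id` and `voice`, and the new `get_tts_voices(language)` helper surfaces the `voices` list declared in a model's YAML schema (see the `en_US/schema.md` hunk). The sketch below ties the pieces together. It is a minimal illustration, not part of this diff: `ModelManager.get_model_instance` and the provider, model, and voice names (`openai`, `tts-1`, `alloy`) are assumptions; only `get_tts_voices` and `invoke_tts` are taken from the hunks above.

```python
# Minimal sketch of the new TTS voice flow (assumptions noted inline).
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType


def synthesize(tenant_id: str, text: str) -> str:
    # Hypothetical lookup helper; the real ModelManager call may differ.
    model_instance = ModelManager().get_model_instance(
        tenant_id=tenant_id,
        provider='openai',         # assumed provider name
        model_type=ModelType.TTS,
        model='tts-1',             # assumed model name
    )

    # New in this diff: list the voices whose `language` entry in the model
    # schema contains the requested language.
    voices = model_instance.get_tts_voices(language='en-US')
    voice = voices[0]['value'] if voices else 'alloy'  # assumed default voice

    # invoke_tts() now requires tenant_id and voice alongside the text.
    return model_instance.invoke_tts(
        content_text=text,
        tenant_id=tenant_id,
        voice=voice,
        streaming=False,
    )
```

Note that `get_tts_model_voices` falls through (implicitly returning `None`) when a model schema declares no `voices` property, so callers should guard the result as above.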
diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index acff4177c3..667ba4c78c 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -1,18 +1,31 @@ import logging -from typing import Generator, List, Optional, Tuple, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast import cohere from cohere.responses import Chat, Generations from cohere.responses.chat import StreamEnd, StreamingChat, StreamTextGeneration from cohere.responses.generation import StreamingGenerations, StreamingText + from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, - PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -26,7 +39,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -126,7 +139,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(str(ex)) def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm model @@ -252,7 +265,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): break def _chat_generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -294,7 +307,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop) def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat, - prompt_messages: list[PromptMessage], stop: 
Optional[List[str]] = None) \ + prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \ -> LLMResult: """ Handle llm chat response @@ -340,7 +353,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat, prompt_messages: list[PromptMessage], - stop: Optional[List[str]] = None) -> Generator: + stop: Optional[list[str]] = None) -> Generator: """ Handle llm chat stream response @@ -415,7 +428,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): index += 1 def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \ - -> Tuple[str, list[dict]]: + -> tuple[str, list[dict]]: """ Convert prompt messages to message and chat histories :param prompt_messages: prompt messages @@ -483,7 +496,7 @@ class CohereLargeLanguageModel(LargeLanguageModel): return response.length - def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int: + def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int: """Calculate num tokens Cohere model.""" messages = [self._convert_prompt_message_to_dict(m) for m in messages] message_strs = [f"{message['role']}: {message['message']}" for message in messages] diff --git a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py index 8c82cce766..7fee57f670 100644 --- a/api/core/model_runtime/model_providers/cohere/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/cohere/rerank/rerank.py @@ -1,9 +1,16 @@ from typing import Optional import cohere + from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py index a239727814..5eec721841 100644 --- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py @@ -1,13 +1,20 @@ import time -from typing import Optional, Tuple +from typing import Optional import cohere import numpy as np from cohere.responses import Tokens + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from 
core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -161,7 +168,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]: + def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]: """ Invoke embedding model diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 6fd5c9144c..686761ab5f 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -1,28 +1,41 @@ import logging -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import google.api_core.exceptions as exceptions import google.generativeai as genai import google.generativeai.client as client -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, - PromptMessageContentType, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers import google -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from google.generativeai.types import ContentType, GenerateContentResponse, HarmBlockThreshold, HarmCategory from google.generativeai.types.content_types import to_part +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + logger = logging.getLogger(__name__) class GoogleLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -91,7 +104,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model diff --git 
a/api/core/model_runtime/model_providers/huggingface_hub/_common.py b/api/core/model_runtime/model_providers/huggingface_hub/_common.py index 1140c947b9..dd8ae526e6 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/_common.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/_common.py @@ -1,6 +1,7 @@ -from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError from huggingface_hub.utils import BadRequestError, HfHubHTTPError +from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError + class _CommonHuggingfaceHub: diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py index e0701dff59..f43a8aedaf 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py @@ -1,23 +1,36 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, - ModelPropertyKey, ModelType, ParameterRule) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub from huggingface_hub import InferenceClient from huggingface_hub.hf_api import HfApi from huggingface_hub.utils import BadRequestError +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub + class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: client = InferenceClient(token=credentials['huggingfacehub_api_token']) diff --git a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py index f0dc632fae..0f0c166f3e 100644 --- a/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/huggingface_hub/text_embedding/text_embedding.py @@ -4,13 +4,14 @@ from typing import Optional import numpy as np import requests +from huggingface_hub import HfApi, InferenceClient + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.huggingface_hub._common import _CommonHuggingfaceHub -from huggingface_hub import HfApi, InferenceClient HUGGINGFACE_ENDPOINT_API = 'https://api.endpoints.huggingface.cloud/v2/endpoint/' diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 50238fbcde..5c146972cd 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -2,14 +2,21 @@ import time from json import JSONDecodeError, dumps from typing import Optional +from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.jina.text_embedding.jina_tokenizer import JinaTokenizer -from requests import post class JinaTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py index 117ef8c399..694f5891f9 100644 --- a/api/core/model_runtime/model_providers/localai/llm/llm.py +++ b/api/core/model_runtime/model_providers/localai/llm/llm.py @@ -1,29 +1,59 @@ +from collections.abc import Generator from os.path import join -from typing import Generator, List, Optional, Union, cast +from typing import cast -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule, ParameterType) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, 
InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils import helper from httpx import Timeout -from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, - NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + Stream, + UnprocessableEntityError, +) from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.utils import helper + class LocalAILarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -34,7 +64,7 @@ class LocalAILarguageModel(LargeLanguageModel): # tools is not supported yet return self._num_tokens_from_messages(prompt_messages, tools=tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for the model; LocalAI does not support this directly @@ -212,7 +242,7 @@ class LocalAILarguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: kwargs = self._to_client_kwargs(credentials) @@ -317,7 +347,7 @@ class LocalAILarguageModel(LargeLanguageModel): return message_dict - def _convert_prompt_message_to_completion_prompts(self, messages: List[PromptMessage]) -> str: + def _convert_prompt_message_to_completion_prompts(self, 
messages: list[PromptMessage]) -> str: """ Convert PromptMessage to completion prompts """ diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py index 9ba94e5f21..6d2278fd54 100644 --- a/api/core/model_runtime/model_providers/localai/localai.py +++ b/api/core/model_runtime/model_providers/localai/localai.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py index 511f09e3e7..39143127eb 100644 --- a/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/localai/text_embedding/text_embedding.py @@ -3,13 +3,20 @@ from json import JSONDecodeError, dumps from os.path import join from typing import Optional +from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from requests import post class LocalAITextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py index 718ebb1013..6c41e0d2a5 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py @@ -1,22 +1,27 @@ -from hashlib import md5 +from collections.abc import Generator from json import dumps, loads -from time import time -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage from requests import Response, post +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -class MinimaxChatCompletion(object): + +class MinimaxChatCompletion: """ Minimax Chat Completion API """ def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: List[MinimaxMessage], model_parameters: dict, - tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ + prompt_messages: 
list[MinimaxMessage], model_parameters: dict, + tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py index 6233af26b6..81ea2e165e 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py +++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py @@ -1,23 +1,28 @@ -from hashlib import md5 +from collections.abc import Generator from json import dumps, loads -from time import time -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) -from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage from requests import Response, post +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) +from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage -class MinimaxChatCompletionPro(object): + +class MinimaxChatCompletionPro: """ Minimax Chat Completion Pro API. The API supports function calling, which is not implemented here yet; the parameters are reserved for it """ def generate(self, model: str, api_key: str, group_id: str, - prompt_messages: List[MinimaxMessage], model_parameters: dict, - tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \ + prompt_messages: list[MinimaxMessage], model_parameters: dict, + tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \ -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]: """ generate chat completion diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py index 2657c85419..cc88d15736 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/llm.py +++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py @@ -1,17 +1,34 @@ -from typing import Generator, List +from collections.abc import Generator from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from 
core.model_runtime.model_providers.minimax.llm.chat_completion import MinimaxChatCompletion from core.model_runtime.model_providers.minimax.llm.chat_completion_pro import MinimaxChatCompletionPro -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage @@ -25,7 +42,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -62,7 +79,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): tools: list[PromptMessageTool] | None = None) -> int: return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for minimax model @@ -77,7 +94,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py index 6229312445..b33a7ca9ac 100644 --- a/api/core/model_runtime/model_providers/minimax/llm/types.py +++ b/api/core/model_runtime/model_providers/minimax/llm/types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict +from typing import Any class MinimaxMessage: @@ -11,11 +11,11 @@ class MinimaxMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - function_call: Dict[str, Any] = None + function_call: dict[str, Any] = None - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value: return { 'sender_type': 'BOT', diff --git a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py index 65f2a9a225..edf4d6005a 100644 --- a/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/minimax/text_embedding/text_embedding.py @@ -1,17 +1,29 @@ import time -from json import dumps, loads +from json import dumps from typing import Optional 
+from requests import post + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.minimax.llm.errors import (BadRequestError, InsufficientAccountBalanceError, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, RateLimitReachedError) -from requests import post +from core.model_runtime.model_providers.minimax.llm.errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class MinimaxTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 06932b018d..185ff62711 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -5,12 +5,13 @@ from collections import OrderedDict from typing import Optional import yaml +from pydantic import BaseModel + from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator -from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -219,7 +220,7 @@ class ModelProviderFactory: # read _position.yaml file position_map = {} if os.path.exists(position_file_path): - with open(position_file_path, 'r', encoding='utf-8') as f: + with open(position_file_path, encoding='utf-8') as f: positions = yaml.safe_load(f) # convert list to dict with key as model provider name, value as index position_map = {position: index for index, position in enumerate(positions)} diff --git a/api/core/model_runtime/model_providers/moonshot/__init__.py b/api/core/model_runtime/model_providers/moonshot/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/moonshot/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/moonshot/_assets/icon_l_en.png new file mode 100644 index 0000000000..a411526d3d Binary files /dev/null and b/api/core/model_runtime/model_providers/moonshot/_assets/icon_l_en.png differ diff --git a/api/core/model_runtime/model_providers/moonshot/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/moonshot/_assets/icon_s_en.png new file mode 100644 index 0000000000..58ba4b4623 Binary files /dev/null and b/api/core/model_runtime/model_providers/moonshot/_assets/icon_s_en.png differ diff --git 
a/api/core/model_runtime/model_providers/moonshot/llm/__init__.py b/api/core/model_runtime/model_providers/moonshot/llm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/moonshot/llm/_position.yaml b/api/core/model_runtime/model_providers/moonshot/llm/_position.yaml new file mode 100644 index 0000000000..1810ec61d6 --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/llm/_position.yaml @@ -0,0 +1,3 @@ +- moonshot-v1-8k +- moonshot-v1-32k +- moonshot-v1-128k diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py new file mode 100644 index 0000000000..5db3e2827b --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -0,0 +1,25 @@ +from collections.abc import Generator +from typing import Optional, Union + +from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel + + +class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + self._add_custom_parameters(credentials) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + + def validate_credentials(self, model: str, credentials: dict) -> None: + self._add_custom_parameters(credentials) + super().validate_credentials(model, credentials) + + @staticmethod + def _add_custom_parameters(credentials: dict) -> None: + credentials['mode'] = 'chat' + credentials['endpoint_url'] = 'https://api.moonshot.cn/v1' diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml new file mode 100644 index 0000000000..28bfaed98a --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-128k.yaml @@ -0,0 +1,25 @@ +model: moonshot-v1-128k +label: + zh_Hans: moonshot-v1-128k + en_US: moonshot-v1-128k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 128000 +pricing: + input: '0.06' + output: '0.06' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml new file mode 100644 index 0000000000..0df1a837f9 --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-32k.yaml @@ -0,0 +1,25 @@ +model: moonshot-v1-32k +label: + zh_Hans: moonshot-v1-32k + en_US: moonshot-v1-32k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 32000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 1024 + min: 1 + max: 32000 +pricing: + input: '0.024' + output: '0.024' + unit: '0.001' + currency: 
RMB diff --git a/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml new file mode 100644 index 0000000000..e4e0a0f069 --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/llm/moonshot-v1-8k.yaml @@ -0,0 +1,25 @@ +model: moonshot-v1-8k +label: + zh_Hans: moonshot-v1-8k + en_US: moonshot-v1-8k +model_type: llm +features: + - agent-thought +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 8192 +pricing: + input: '0.012' + output: '0.012' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.py b/api/core/model_runtime/model_providers/moonshot/moonshot.py new file mode 100644 index 0000000000..5654ae1459 --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.py @@ -0,0 +1,30 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class MoonshotProvider(ModelProvider): + + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. + """ + try: + model_instance = self.get_model_instance(ModelType.LLM) + + model_instance.validate_credentials( + model='moonshot-v1-8k', + credentials=credentials + ) + except CredentialsValidateFailedError as ex: + raise ex + except Exception as ex: + logger.exception(f'{self.get_provider_schema().provider} credentials validate failed') + raise ex diff --git a/api/core/model_runtime/model_providers/moonshot/moonshot.yaml b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml new file mode 100644 index 0000000000..1885ee9d94 --- /dev/null +++ b/api/core/model_runtime/model_providers/moonshot/moonshot.yaml @@ -0,0 +1,32 @@ +provider: moonshot +label: + zh_Hans: 月之暗面 + en_US: Moonshot +description: + en_US: Models provided by Moonshot, such as moonshot-v1-8k, moonshot-v1-32k, and moonshot-v1-128k. 
+ zh_Hans: Moonshot 提供的模型,例如 moonshot-v1-8k、moonshot-v1-32k 和 moonshot-v1-128k。 +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +background: "#FFFFFF" +help: + title: + en_US: Get your API Key from Moonshot + zh_Hans: 从 Moonshot 获取 API Key + url: + en_US: https://platform.moonshot.cn/console/api-keys +supported_model_types: + - llm +configurate_methods: + - predefined-model +provider_credential_schema: + credential_form_schemas: + - variable: api_key + label: + en_US: API Key + type: secret-input + required: true + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index 27c4be125a..e4388699e3 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -1,21 +1,44 @@ import json import logging import re +from collections.abc import Generator from decimal import Decimal -from typing import Generator, List, Optional, Union, cast +from typing import Optional, Union, cast from urllib.parse import urljoin import requests + from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, I18nObject, - ModelFeature, ModelPropertyKey, ModelType, ParameterRule, - ParameterType, PriceConfig) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + I18nObject, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -29,7 +52,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -109,7 +132,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}') def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: 
list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -376,7 +399,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return message_dict - def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int: """ Calculate num tokens. diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index 5d96ac65ff..fd73728b78 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -7,12 +7,25 @@ from urllib.parse import urljoin import numpy as np import requests + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - PriceConfig, PriceType) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel diff --git a/api/core/model_runtime/model_providers/openai/_common.py b/api/core/model_runtime/model_providers/openai/_common.py index 91705c3ba8..436461c11e 100644 --- a/api/core/model_runtime/model_providers/openai/_common.py +++ b/api/core/model_runtime/model_providers/openai/_common.py @@ -1,8 +1,15 @@ import openai -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) from httpx import Timeout +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + class _CommonOpenAI: def _to_credential_kwargs(self, credentials: dict) -> dict: diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 7722c69a95..2a1137d443 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -1,24 +1,31 @@ import logging -from typing import Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast import tiktoken -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, - PromptMessageFunction, PromptMessageTool, 
SystemPromptMessage, - TextPromptMessageContent, ToolPromptMessage, - UserPromptMessage) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from core.model_runtime.utils import helper from openai import OpenAI, Stream from openai.types import Completion from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType, PriceConfig +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.openai._common import _CommonOpenAI + logger = logging.getLogger(__name__) @@ -29,7 +36,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -209,7 +216,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return ai_model_entities def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm completion model @@ -360,7 +367,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): def _chat_generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke llm chat model @@ -491,8 +498,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): continue delta = chunk.choices[0] + has_finish_reason = delta.finish_reason is not None - if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == '') and \ + if not has_finish_reason and (delta.delta.content is None or delta.delta.content == '') and \ delta.delta.function_call is None: continue @@ -514,7 +522,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if 
assistant_message_function_call: # start of stream function call delta_assistant_message_function_call_storage = assistant_message_function_call - continue + if not has_finish_reason: + continue # tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls) function_call = self._extract_response_function_call(assistant_message_function_call) @@ -528,7 +537,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): full_assistant_content += delta.delta.content if delta.delta.content else '' - if delta.finish_reason is not None: + if has_finish_reason: # calculate num tokens prompt_tokens = self._num_tokens_from_messages(model, prompt_messages, tools) @@ -698,7 +707,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage], + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. diff --git a/api/core/model_runtime/model_providers/openai/moderation/moderation.py b/api/core/model_runtime/model_providers/openai/moderation/moderation.py index 2a0901d752..b1d0e57ad2 100644 --- a/api/core/model_runtime/model_providers/openai/moderation/moderation.py +++ b/api/core/model_runtime/model_providers/openai/moderation/moderation.py @@ -1,11 +1,12 @@ from typing import Optional +from openai import OpenAI +from openai.types import ModerationCreateResponse + from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.moderation_model import ModerationModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from openai import OpenAI -from openai.types import ModerationCreateResponse class OpenAIModerationModel(_CommonOpenAI, ModerationModel): diff --git a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py index b2b337a563..efbdd054f9 100644 --- a/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py +++ b/api/core/model_runtime/model_providers/openai/speech2text/speech2text.py @@ -1,9 +1,10 @@ from typing import IO, Optional +from openai import OpenAI + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from openai import OpenAI class OpenAISpeech2TextModel(_CommonOpenAI, Speech2TextModel): diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py index 87a5cf1a2a..e23a2edf87 100644 --- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py @@ -1,15 +1,16 @@ import base64 import time -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import tiktoken +from openai import OpenAI + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import 
CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from openai import OpenAI class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): @@ -161,7 +162,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel): raise CredentialsValidateFailedError(str(ex)) def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> Tuple[list[list[float]], int]: + extra_model_kwargs: dict) -> tuple[list[list[float]], int]: """ Invoke embedding model diff --git a/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml b/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml index aa7ed537a4..72f15134ea 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml +++ b/api/core/model_runtime/model_providers/openai/tts/tts-1-hd.yaml @@ -2,6 +2,30 @@ model: tts-1-hd model_type: tts model_properties: default_voice: 'alloy' + voices: + - mode: 'alloy' + name: 'Alloy' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + - mode: 'echo' + name: 'Echo' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + - mode: 'fable' + name: 'Fable' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + - mode: 'onyx' + name: 'Onyx' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + - mode: 'nova' + name: 'Nova' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] + - mode: 'shimmer' + name: 'Shimmer' + language: [ 'zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID' ] word_limit: 120 audio_type: 'mp3' max_workers: 5 +pricing: + input: '0.03' + output: '0' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml b/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml index 96f54a7340..8d222fed64 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml +++ b/api/core/model_runtime/model_providers/openai/tts/tts-1.yaml @@ -2,6 +2,30 @@ model: tts-1 model_type: tts model_properties: default_voice: 'alloy' + voices: + - mode: 'alloy' + name: 'Alloy' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'echo' + name: 'Echo' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'fable' + name: 'Fable' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'onyx' + name: 'Onyx' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'nova' + name: 'Nova' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] + - mode: 'shimmer' + name: 'Shimmer' + language: ['zh-Hans', 'en-US', 'de-DE', 'fr-FR', 'es-ES', 'it-IT', 'th-TH', 'id-ID'] word_limit: 120 audio_type: 'mp3' max_workers: 5 +pricing: + input: '0.015' + output: '0' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index 95a88e9bec..b1718c063c 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -3,40 +3,48 @@ from functools import reduce from io import BytesIO 
from typing import Optional +from flask import Response, stream_with_context +from openai import OpenAI +from pydub import AudioSegment + from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.openai._common import _CommonOpenAI -from flask import Response, stream_with_context -from openai import OpenAI -from pydub import AudioSegment +from extensions.ext_storage import storage class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): """ Model class for OpenAI Speech to text model. """ - def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any: + + def _invoke(self, model: str, tenant_id: str, credentials: dict, + content_text: str, voice: str, streaming: bool, user: Optional[str] = None) -> any: """ _invoke text2speech model :param model: model name + :param tenant_id: user tenant id :param credentials: model credentials :param content_text: text content to be translated + :param voice: model timbre :param streaming: output is streaming :param user: unique user id :return: text translated to audio file """ - self._is_ffmpeg_installed() audio_type = self._get_model_audio_type(model, credentials) + if not voice: + voice = self._get_model_default_voice(model, credentials) if streaming: return Response(stream_with_context(self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, - user=user)), + tenant_id=tenant_id, + voice=voice)), status=200, mimetype=f'audio/{audio_type}') else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, user=user) + return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -51,91 +59,96 @@ class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): self._tts_invoke( model=model, credentials=credentials, - content_text='Hello world!', - user=user + content_text='Hello Dify!', + voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> Response: + def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: """ _tts_invoke text2speech model :param model: model name :param credentials: model credentials :param content_text: text content to be translated - :param user: unique user id + :param voice: model timbre :return: text translated to audio file """ audio_type = self._get_model_audio_type(model, credentials) word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) - try: sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) audio_bytes_list = list() # Create a thread pool and map the function to the list of sentences with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(self._process_sentence, sentence, model, credentials) for sentence - in sentences] + futures = [executor.submit(self._process_sentence, sentence=sentence, model=model, voice=voice, + credentials=credentials) for sentence in sentences] for future in futures: try: 
- audio_bytes_list.append(future.result()) + if future.result(): + audio_bytes_list.append(future.result()) except Exception as ex: raise InvokeBadRequestError(str(ex)) - audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] - combined_segment = reduce(lambda x, y: x + y, audio_segments) - buffer: BytesIO = BytesIO() - combined_segment.export(buffer, format=audio_type) - buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + if len(audio_bytes_list) > 0: + audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in + audio_bytes_list if audio_bytes] + combined_segment = reduce(lambda x, y: x + y, audio_segments) + buffer: BytesIO = BytesIO() + combined_segment.export(buffer, format=audio_type) + buffer.seek(0) + return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") except Exception as ex: raise InvokeBadRequestError(str(ex)) # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any: + def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name + :param tenant_id: user tenant id :param credentials: model credentials :param content_text: text content to be translated - :param user: unique user id + :param voice: model timbre :return: text translated to audio file """ # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - voice_name = self._get_model_voice(model, credentials) + if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + voice = self._get_model_default_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) tts_file_id = self._get_file_name(content_text) - file_path = f'storage/generate_files/{audio_type}/{tts_file_id}.{audio_type}' + file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' try: client = OpenAI(**credentials_kwargs) sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) for sentence in sentences: - response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip()) - response.stream_to_file(file_path) + response = client.audio.speech.create(model=model, voice=voice, input=sentence.strip()) + # response.stream_to_file(file_path) + storage.save(file_path, response.read()) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, model: str, credentials: dict): + def _process_sentence(self, sentence: str, model: str, + voice, credentials: dict): """ _tts_invoke openai text2speech model api :param model: model name :param credentials: model credentials + :param voice: model timbre :param sentence: text content to be translated :return: text translated to audio file """ # transform credentials to kwargs for model instance credentials_kwargs = self._to_credential_kwargs(credentials) - voice_name = self._get_model_voice(model, credentials) - client = OpenAI(**credentials_kwargs) - response = client.audio.speech.create(model=model, voice=voice_name, input=sentence.strip()) + response = client.audio.speech.create(model=model, voice=voice, 
input=sentence.strip()) if isinstance(response.read(), bytes): return response.read() diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py index 9b7b052b99..51950ca377 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py @@ -1,13 +1,14 @@ -from decimal import Decimal import requests -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, - ModelPropertyKey, ModelType, ParameterRule, ParameterType, - PriceConfig) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) + +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) class _CommonOAI_API_Compat: diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index 53ee5817d9..cf90633aa6 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -1,19 +1,36 @@ import json import logging +from collections.abc import Generator from decimal import Decimal -from typing import Generator, List, Optional, Union, cast +from typing import Optional, Union, cast from urllib.parse import urljoin import requests + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContent, PromptMessageContentType, - PromptMessageFunction, PromptMessageTool, SystemPromptMessage, - ToolPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, DefaultParameterName, FetchFrom, - ModelPropertyKey, ModelType, ParameterRule, ParameterType, - PriceConfig) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + PromptMessageContent, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, +) from core.model_runtime.errors.invoke import InvokeError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -30,7 +47,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, 
stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -130,16 +147,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') if (completion_type is LLMMode.CHAT and ('object' not in json_result or json_result['object'] != 'chat.completion')): raise CredentialsValidateFailedError( - f'Credentials validation failed: invalid response object, must be \'chat.completion\'') + 'Credentials validation failed: invalid response object, must be \'chat.completion\'') elif (completion_type is LLMMode.COMPLETION and ('object' not in json_result or json_result['object'] != 'text_completion')): raise CredentialsValidateFailedError( - f'Credentials validation failed: invalid response object, must be \'text_completion\'') + 'Credentials validation failed: invalid response object, must be \'text_completion\'') except CredentialsValidateFailedError: raise except Exception as ex: @@ -229,7 +246,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard. def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, \ user: Optional[str] = None) -> Union[LLMResult, Generator]: """ @@ -350,13 +367,16 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter): if chunk: + #ignore sse comments + if chunk.startswith(':'): + continue decoded_chunk = chunk.strip().lstrip('data: ').lstrip() chunk_json = None try: chunk_json = json.loads(decoded_chunk) # stream ended except json.JSONDecodeError as e: - logger.error(f"decoded_chunk error,delimiter={delimiter},decoded_chunk={decoded_chunk}") + logger.error(f"decoded_chunk error: {e}, delimiter={delimiter}, decoded_chunk={decoded_chunk}") yield create_final_llm_result_chunk( index=chunk_index + 1, message=AssistantPromptMessage(content=""), @@ -551,7 +571,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage], + def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """ Approximate num tokens with GPT2 tokenizer. 
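
Note on the streaming change in openai_api_compatible/llm/llm.py above: some OpenAI-compatible servers interleave SSE comment lines (per the SSE spec, any line beginning with ':', often sent as keep-alives) into the event stream, and the delta handler now skips them before attempting to JSON-decode a chunk. A minimal standalone sketch of the same rule follows; the helper name and the '[DONE]' sentinel handling are illustrative assumptions, not part of the patch (the patched code instead treats a JSON decode failure as end-of-stream):

    import json
    from typing import Optional

    def parse_sse_line(line: str) -> Optional[dict]:
        # A line starting with ':' is an SSE comment (e.g. ': keep-alive'); ignore it.
        if not line or line.startswith(':'):
            return None
        # Only 'data:' fields carry a completion chunk payload.
        if not line.startswith('data:'):
            return None
        payload = line[len('data:'):].strip()
        if payload == '[DONE]':  # OpenAI-style end-of-stream marker (assumption for this sketch)
            return None
        try:
            return json.loads(payload)
        except json.JSONDecodeError:
            return None

Feeding each line from response.iter_lines(decode_unicode=True, delimiter=delimiter) through such a filter yields only decodable chunk payloads, which is what the comment-skip added above achieves inline.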
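The TTS reworks in this patch (openai/tts/tts.py above, and tongyi/tts/tts.py further below) share one non-streaming pattern: split the input into word-limited sentences, synthesize each sentence on a thread pool, drop empty results, and concatenate the per-sentence audio with pydub before returning a single response. A minimal sketch of that pattern under stated assumptions — the synthesize callable stands in for the provider API call and is hypothetical, and pydub needs ffmpeg available for compressed formats:

    import concurrent.futures
    from functools import reduce
    from io import BytesIO
    from typing import Callable

    from pydub import AudioSegment

    def combine_tts_audio(sentences: list[str], synthesize: Callable[[str], bytes],
                          audio_type: str = 'mp3', max_workers: int = 5) -> bytes:
        # Synthesize concurrently; iterating futures in submission order keeps sentence order.
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(synthesize, sentence) for sentence in sentences]
            audio_bytes_list = [f.result() for f in futures if f.result()]

        if not audio_bytes_list:
            return b''

        # pydub overloads '+' as append, so reduce() concatenates the decoded segments.
        segments = [AudioSegment.from_file(BytesIO(b), format=audio_type)
                    for b in audio_bytes_list]
        combined = reduce(lambda x, y: x + y, segments)

        buffer = BytesIO()
        combined.export(buffer, format=audio_type)
        buffer.seek(0)
        return buffer.read()

The streaming paths diverge from this: rather than returning bytes directly, both providers now persist each synthesized chunk through the storage extension under generate_files/audio/{tenant_id}/, using a file id derived from the input text.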
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py index 407eefa701..3445ebbaf7 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index b735fdb792..3467cd6dfd 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -6,9 +6,16 @@ from urllib.parse import urljoin import numpy as np import requests + from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - PriceConfig, PriceType) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + PriceConfig, + PriceType, +) from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -172,11 +179,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel): try: json_result = response.json() except json.JSONDecodeError as e: - raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error') + raise CredentialsValidateFailedError('Credentials validation failed: JSON decode error') if 'model' not in json_result: raise CredentialsValidateFailedError( - f'Credentials validation failed: invalid response') + 'Credentials validation failed: invalid response') except CredentialsValidateFailedError: raise except Exception as ex: diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py index af62ddf92f..8ea5819bde 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/llm.py +++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py @@ -1,28 +1,46 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule, ParameterType) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from 
core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.openllm.llm.openllm_generate import OpenLLMGenerate, OpenLLMGenerateMessage -from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import (BadRequestError, - InsufficientAccountBalanceError, - InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) +from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import ( + BadRequestError, + InsufficientAccountBalanceError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class OpenLLMLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) @@ -59,7 +77,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): tools: list[PromptMessageTool] | None = None) -> int: return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int: """ Calculate num tokens for OpenLLM model it's a generate model, so we just join them by spe @@ -69,7 +87,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: client = OpenLLMGenerate() response = client.generate( diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 2d9a10fa2a..43258d1e5e 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -1,16 +1,17 @@ +from collections.abc import Generator from enum import Enum from json import dumps, loads -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union -from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import (BadRequestError, - InsufficientAccountBalanceError, - InternalServerError, - 
InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) from requests import Response, post from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema +from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors import ( + BadRequestError, + InternalServerError, + InvalidAuthenticationError, +) + class OpenLLMGenerateMessage: class Role(Enum): @@ -19,10 +20,10 @@ class OpenLLMGenerateMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -33,10 +34,10 @@ class OpenLLMGenerateMessage: self.role = role -class OpenLLMGenerate(object): +class OpenLLMGenerate: def generate( - self, server_url: str, model_name: str, stream: bool, model_parameters: Dict[str, Any], - stop: List[str], prompt_messages: List[OpenLLMGenerateMessage], user: str, + self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], + stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: raise InvalidAuthenticationError('Invalid server URL') diff --git a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py index 2f30427d36..33847c0cb3 100644 --- a/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openllm/text_embedding/text_embedding.py @@ -1,15 +1,22 @@ import time -from json import dumps, loads +from json import dumps from typing import Optional +from requests import post +from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema + from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from requests import post -from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema class OpenLLMTextEmbeddingModel(TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/replicate/_common.py b/api/core/model_runtime/model_providers/replicate/_common.py index ad130cabbc..29d8427d8e 100644 --- a/api/core/model_runtime/model_providers/replicate/_common.py +++ b/api/core/model_runtime/model_providers/replicate/_common.py @@ -1,6 +1,7 @@ -from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError from replicate.exceptions import ModelError, ReplicateError +from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError + class _CommonReplicate: diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py index 69c0a82636..ee2de85607 
100644 --- a/api/core/model_runtime/model_providers/replicate/llm/llm.py +++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py @@ -1,23 +1,36 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, - ParameterRule) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.replicate._common import _CommonReplicate from replicate import Client as ReplicateClient from replicate.exceptions import ReplicateError from replicate.prediction import Prediction +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.replicate._common import _CommonReplicate + class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: version = credentials['model_version'] diff --git a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py index 37a275614c..a481aebc99 100644 --- a/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/replicate/text_embedding/text_embedding.py @@ -2,13 +2,14 @@ import json import time from typing import Optional +from replicate import Client as ReplicateClient + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.replicate._common import _CommonReplicate -from replicate import Client as ReplicateClient class 
ReplicateEmbeddingModel(_CommonReplicate, TextEmbeddingModel): diff --git a/api/core/model_runtime/model_providers/spark/llm/_client.py b/api/core/model_runtime/model_providers/spark/llm/_client.py index 9390f4351b..a4659454ee 100644 --- a/api/core/model_runtime/model_providers/spark/llm/_client.py +++ b/api/core/model_runtime/model_providers/spark/llm/_client.py @@ -1,5 +1,4 @@ import base64 -import datetime import hashlib import hmac import json diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py index 33475f5769..65beae517c 100644 --- a/api/core/model_runtime/model_providers/spark/llm/llm.py +++ b/api/core/model_runtime/model_providers/spark/llm/llm.py @@ -1,11 +1,23 @@ import threading -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -16,7 +28,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -75,7 +87,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -233,7 +245,7 @@ class SparkLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git a/api/core/model_runtime/model_providers/spark/spark.py b/api/core/model_runtime/model_providers/spark/spark.py index c8bea10390..b3695e0501 100644 --- a/api/core/model_runtime/model_providers/spark/spark.py +++ b/api/core/model_runtime/model_providers/spark/spark.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider 
import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py index 89198fe4b0..b312d99b1c 100644 --- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py +++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -14,7 +15,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) @@ -27,7 +28,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel): return super().validate_credentials(model, cred_with_endpoint) def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: cred_with_endpoint = self._update_endpoint_url(credentials=credentials) diff --git a/api/core/model_runtime/model_providers/togetherai/togetherai.py b/api/core/model_runtime/model_providers/togetherai/togetherai.py index e2ede35d69..ffce4794e7 100644 --- a/api/core/model_runtime/model_providers/togetherai/togetherai.py +++ b/api/core/model_runtime/model_providers/togetherai/togetherai.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/tongyi/llm/_client.py b/api/core/model_runtime/model_providers/tongyi/llm/_client.py index 2aab69af7a..cfe33558e1 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/_client.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/_client.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.llms import Tongyi @@ -8,7 +8,7 @@ from langchain.schema import Generation, LLMResult class EnhanceTongyi(Tongyi): @property - def _default_params(self) -> Dict[str, Any]: + def _default_params(self) -> dict[str, Any]: """Get the default parameters for calling OpenAI API.""" normal_params = { "top_p": self.top_p, @@ -19,13 +19,13 @@ class EnhanceTongyi(Tongyi): def _generate( self, - prompts: List[str], - stop: Optional[List[str]] = None, + prompts: list[str], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: generations = [] - params: Dict[str, Any] = { + params: dict[str, Any] = { **{"model": 
self.model_name}, **self._default_params, **kwargs, diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py index 63f300fc19..7ae8b87764 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py +++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py @@ -1,18 +1,37 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dashscope import get_tokenizer from dashscope.api_entities.dashscope_response import DashScopeAPIResponse -from dashscope.common.error import (AuthenticationError, InvalidParameter, RequestFailure, ServiceUnavailableError, - UnsupportedHTTPMethod, UnsupportedModel) +from dashscope.common.error import ( + AuthenticationError, + InvalidParameter, + RequestFailure, + ServiceUnavailableError, + UnsupportedHTTPMethod, + UnsupportedModel, +) from langchain.llms.tongyi import generate_with_retry, stream_generate_with_retry +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from ._client import EnhanceTongyi @@ -20,7 +39,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -82,7 +101,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -250,7 +269,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git 
a/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml b/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml index 8746fb9f02..e533d5812d 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts-1.yaml @@ -1,7 +1,134 @@ model: tts-1 model_type: tts model_properties: - default_voice: 'sambert-zhiru-v1' # 音色参考 https://help.aliyun.com/zh/dashscope/model-list 配置 + default_voice: 'sambert-zhiru-v1' + voices: + - mode: "sambert-zhinan-v1" + name: "知楠(广告男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiqi-v1" + name: "知琪(温柔女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhichu-v1" + name: "知厨(新闻播报)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhide-v1" + name: "知德(新闻男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhijia-v1" + name: "知佳(标准女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiru-v1" + name: "知茹(新闻女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiqian-v1" + name: "知倩(配音解说、新闻播报)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhixiang-v1" + name: "知祥(配音解说)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiwei-v1" + name: "知薇(萝莉女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhihao-v1" + name: "知浩(咨询男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhijing-v1" + name: "知婧(严厉女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiming-v1" + name: "知茗(诙谐男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhimo-v1" + name: "知墨(情感男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhina-v1" + name: "知娜(浙普女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhishu-v1" + name: "知树(资讯男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhistella-v1" + name: "知莎(知性女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiting-v1" + name: "知婷(电台女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhixiao-v1" + name: "知笑(资讯女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiya-v1" + name: "知雅(严厉女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiye-v1" + name: "知晔(青年男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiying-v1" + name: "知颖(软萌童声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhiyuan-v1" + name: "知媛(知心姐姐)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhigui-v1" + name: "知柜(直播女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhishuo-v1" + name: "知硕(自然男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhimiao-emo-v1" + name: "知妙(多种情感女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhimao-v1" + name: "知猫(直播女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhilun-v1" + name: "知伦(悬疑解说)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhifei-v1" + name: "知飞(激昂解说)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-zhida-v1" + name: "知达(标准男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "sambert-camila-v1" + name: "Camila(西班牙语女声)" + language: [ "es-ES" ] + - mode: "sambert-perla-v1" + name: "Perla(意大利语女声)" + language: [ "it-IT" ] + - mode: "sambert-indah-v1" + name: "Indah(印尼语女声)" + language: [ "id-ID" ] + - mode: "sambert-clara-v1" + name: "Clara(法语女声)" + language: [ "fr-FR" ] + - mode: "sambert-hanna-v1" + name: "Hanna(德语女声)" + language: [ "de-DE" ] + - mode: "sambert-beth-v1" + name: "Beth(咨询女声)" + language: [ "en-US" ] + - mode: "sambert-betty-v1" + name: "Betty(客服女声)" + language: [ "en-US" ] + - mode: "sambert-cally-v1" + name: "Cally(自然女声)" + language: [ 
"en-US" ] + - mode: "sambert-cindy-v1" + name: "Cindy(对话女声)" + language: [ "en-US" ] + - mode: "sambert-eva-v1" + name: "Eva(陪伴女声)" + language: [ "en-US" ] + - mode: "sambert-donna-v1" + name: "Donna(教育女声)" + language: [ "en-US" ] + - mode: "sambert-brian-v1" + name: "Brian(客服男声)" + language: [ "en-US" ] + - mode: "sambert-waan-v1" + name: "Waan(泰语女声)" + language: [ "th-TH" ] word_limit: 120 audio_type: 'mp3' max_workers: 5 diff --git a/api/core/model_runtime/model_providers/tongyi/tts/tts.py b/api/core/model_runtime/model_providers/tongyi/tts/tts.py index a6fc201080..6bd17684fe 100644 --- a/api/core/model_runtime/model_providers/tongyi/tts/tts.py +++ b/api/core/model_runtime/model_providers/tongyi/tts/tts.py @@ -4,39 +4,47 @@ from io import BytesIO from typing import Optional import dashscope +from flask import Response, stream_with_context +from pydub import AudioSegment + from core.model_runtime.errors.invoke import InvokeBadRequestError from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.model_runtime.model_providers.tongyi._common import _CommonTongyi -from flask import Response, stream_with_context -from pydub import AudioSegment +from extensions.ext_storage import storage class TongyiText2SpeechModel(_CommonTongyi, TTSModel): """ Model class for Tongyi Speech to text model. """ - def _invoke(self, model: str, credentials: dict, content_text: str, streaming: bool, user: Optional[str] = None) -> any: + + def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, streaming: bool, + user: Optional[str] = None) -> any: """ _invoke text2speech model :param model: model name + :param tenant_id: user tenant id :param credentials: model credentials + :param voice: model timbre :param content_text: text content to be translated :param streaming: output is streaming :param user: unique user id :return: text translated to audio file """ - self._is_ffmpeg_installed() audio_type = self._get_model_audio_type(model, credentials) + if not voice or voice not in self.get_tts_model_voices(model=model, credentials=credentials): + voice = self._get_model_default_voice(model, credentials) if streaming: return Response(stream_with_context(self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, - user=user)), + voice=voice, + tenant_id=tenant_id)), status=200, mimetype=f'audio/{audio_type}') else: - return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, user=user) + return self._tts_invoke(model=model, credentials=credentials, content_text=content_text, voice=voice) def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ @@ -51,91 +59,96 @@ class TongyiText2SpeechModel(_CommonTongyi, TTSModel): self._tts_invoke( model=model, credentials=credentials, - content_text='Hello world!', - user=user + content_text='Hello Dify!', + voice=self._get_model_default_voice(model, credentials), ) except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def _tts_invoke(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> Response: + def _tts_invoke(self, model: str, credentials: dict, content_text: str, voice: str) -> Response: """ _tts_invoke text2speech model :param model: model name :param credentials: model credentials + :param voice: model timbre :param content_text: text content to be translated - :param user: 
unique user id :return: text translated to audio file """ audio_type = self._get_model_audio_type(model, credentials) word_limit = self._get_model_word_limit(model, credentials) max_workers = self._get_model_workers_limit(model, credentials) - try: sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) audio_bytes_list = list() # Create a thread pool and map the function to the list of sentences with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(self._process_sentence, model=model, sentence=sentence, - credentials=credentials, audio_type=audio_type) for sentence in sentences] + futures = [executor.submit(self._process_sentence, sentence=sentence, + credentials=credentials, voice=voice, audio_type=audio_type) for sentence in + sentences] for future in futures: try: - audio_bytes_list.append(future.result()) + if future.result(): + audio_bytes_list.append(future.result()) except Exception as ex: raise InvokeBadRequestError(str(ex)) - audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in - audio_bytes_list if audio_bytes] - combined_segment = reduce(lambda x, y: x + y, audio_segments) - buffer: BytesIO = BytesIO() - combined_segment.export(buffer, format=audio_type) - buffer.seek(0) - return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") + if len(audio_bytes_list) > 0: + audio_segments = [AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type) for audio_bytes in + audio_bytes_list if audio_bytes] + combined_segment = reduce(lambda x, y: x + y, audio_segments) + buffer: BytesIO = BytesIO() + combined_segment.export(buffer, format=audio_type) + buffer.seek(0) + return Response(buffer.read(), status=200, mimetype=f"audio/{audio_type}") except Exception as ex: raise InvokeBadRequestError(str(ex)) # Todo: To improve the streaming function - def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, user: Optional[str] = None) -> any: + def _tts_invoke_streaming(self, model: str, tenant_id: str, credentials: dict, content_text: str, + voice: str) -> any: """ _tts_invoke_streaming text2speech model :param model: model name + :param tenant_id: user tenant id :param credentials: model credentials + :param voice: model timbre :param content_text: text content to be translated - :param user: unique user id :return: text translated to audio file """ - # transform credentials to kwargs for model instance dashscope.api_key = credentials.get('dashscope_api_key') - voice_name = self._get_model_voice(model, credentials) word_limit = self._get_model_word_limit(model, credentials) audio_type = self._get_model_audio_type(model, credentials) + tts_file_id = self._get_file_name(content_text) + file_path = f'generate_files/audio/{tenant_id}/{tts_file_id}.{audio_type}' try: sentences = list(self._split_text_into_sentences(text=content_text, limit=word_limit)) for sentence in sentences: - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice_name, sample_rate=48000, text=sentence.strip(), + response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, + text=sentence.strip(), format=audio_type, word_timestamp_enabled=True, phoneme_timestamp_enabled=True) if isinstance(response.get_audio_data(), bytes): - return response.get_audio_data() + storage.save(file_path, response.get_audio_data()) except Exception as ex: raise InvokeBadRequestError(str(ex)) - def _process_sentence(self, sentence: str, 
model: str, credentials: dict, audio_type: str): + @staticmethod + def _process_sentence(sentence: str, credentials: dict, voice: str, audio_type: str): """ _tts_invoke Tongyi text2speech model api - :param model: model name :param credentials: model credentials :param sentence: text content to be translated + :param voice: model timbre :param audio_type: audio file type :return: text translated to audio file """ - # transform credentials to kwargs for model instance dashscope.api_key = credentials.get('dashscope_api_key') - voice_name = self._get_model_voice(model, credentials) - - response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice_name, sample_rate=48000, text=sentence.strip(), format=audio_type) + response = dashscope.audio.tts.SpeechSynthesizer.call(model=voice, sample_rate=48000, + text=sentence.strip(), + format=audio_type) if isinstance(response.get_audio_data(), bytes): return response.get_audio_data() diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py index 65081a9665..81868aeed1 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py @@ -1,18 +1,23 @@ +from collections.abc import Generator from datetime import datetime, timedelta from enum import Enum from json import dumps, loads from threading import Lock -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union -from core.model_runtime.entities.message_entities import PromptMessageTool -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (BadRequestError, InternalServerError, - InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) from requests import Response, post +from core.model_runtime.entities.message_entities import PromptMessageTool +from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( + BadRequestError, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) + # map api_key to access_token -baidu_access_tokens: Dict[str, 'BaiduAccessToken'] = {} +baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {} baidu_access_tokens_lock = Lock() class BaiduAccessToken: @@ -101,10 +106,10 @@ class ErnieMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -114,7 +119,7 @@ class ErnieMessage: self.content = content self.role = role -class ErnieBotModel(object): +class ErnieBotModel: api_bases = { 'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions', 'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro', @@ -134,9 +139,9 @@ class ErnieBotModel(object): self.api_key = api_key self.secret_key = secret_key - def generate(self, model: str, stream: bool, messages: List[ErnieMessage], - parameters: Dict[str, Any], timeout: int, tools: List[PromptMessageTool], \ - stop: List[str], user: str) \ + def generate(self, model: str, stream: bool, messages: list[ErnieMessage], + parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \ + stop: list[str], user: str) \ -> Union[Generator[ErnieMessage, None, None], ErnieMessage]: # check parameters @@ -212,11 +217,11 @@ class ErnieBotModel(object): token = 
BaiduAccessToken.get_access_token(self.api_key, self.secret_key) return token.access_token - def _copy_messages(self, messages: List[ErnieMessage]) -> List[ErnieMessage]: + def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]: return [ErnieMessage(message.content, message.role) for message in messages] - def _check_parameters(self, model: str, parameters: Dict[str, Any], - tools: List[PromptMessageTool], stop: List[str]) -> None: + def _check_parameters(self, model: str, parameters: dict[str, Any], + tools: list[PromptMessageTool], stop: list[str]) -> None: if model not in self.api_bases: raise BadRequestError(f'Invalid model: {model}') @@ -227,40 +232,40 @@ class ErnieBotModel(object): # so, we just disable function calling for now. if tools is not None and len(tools) > 0: - raise BadRequestError(f'function calling is not supported yet.') + raise BadRequestError('function calling is not supported yet.') if stop is not None: if len(stop) > 4: - raise BadRequestError(f'stop list should not exceed 4 items.') + raise BadRequestError('stop list should not exceed 4 items.') for s in stop: if len(s) > 20: - raise BadRequestError(f'stop item should not exceed 20 characters.') + raise BadRequestError('stop item should not exceed 20 characters.') - def _build_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, parameters: Dict[str, Any], - tools: List[PromptMessageTool], stop: List[str], user: str) -> Dict[str, Any]: + def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any], + tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]: # if model in self.function_calling_supports: # return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user) return self._build_chat_request_body(model, messages, stream, parameters, stop, user) - def _build_function_calling_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, - parameters: Dict[str, Any], tools: List[PromptMessageTool], - stop: List[str], user: str) \ - -> Dict[str, Any]: + def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, + parameters: dict[str, Any], tools: list[PromptMessageTool], + stop: list[str], user: str) \ + -> dict[str, Any]: if len(messages) % 2 == 0: - raise BadRequestError(f'The number of messages should be odd.') + raise BadRequestError('The number of messages should be odd.') if messages[0].role == 'function': - raise BadRequestError(f'The first message should be user message.') + raise BadRequestError('The first message should be user message.') """ TODO: implement function calling """ - def _build_chat_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, - parameters: Dict[str, Any], stop: List[str], user: str) \ - -> Dict[str, Any]: + def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, + parameters: dict[str, Any], stop: list[str], user: str) \ + -> dict[str, Any]: if len(messages) == 0: - raise BadRequestError(f'The number of messages should not be zero.') + raise BadRequestError('The number of messages should not be zero.') # check if the first element is system, shift it system_message = '' @@ -269,9 +274,9 @@ class ErnieBotModel(object): system_message = message.content if len(messages) % 2 == 0: - raise BadRequestError(f'The number of messages should be odd.') + raise BadRequestError('The number of messages should be odd.') if 
messages[0].role != 'user': - raise BadRequestError(f'The first message should be user message.') + raise BadRequestError('The first message should be user message.') body = { 'messages': [message.to_dict() for message in messages], 'stream': stream, diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index 27b2bce9af..51b3c97497 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -1,23 +1,39 @@ -from typing import Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import cast -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.wenxin.llm.ernie_bot import BaiduAccessToken, ErnieBotModel, ErnieMessage -from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (BadRequestError, InsufficientAccountBalance, - InternalServerError, InvalidAPIKeyError, - InvalidAuthenticationError, - RateLimitReachedError) +from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import ( + BadRequestError, + InsufficientAccountBalance, + InternalServerError, + InvalidAPIKeyError, + InvalidAuthenticationError, + RateLimitReachedError, +) class ErnieBotLarguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -28,7 +44,7 @@ class ErnieBotLarguageModel(LargeLanguageModel): # tools is not supported yet return self._num_tokens_from_messages(prompt_messages) - def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int: + def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int: """Calculate num tokens for baichuan model""" def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -63,7 +79,7 @@ class ErnieBotLarguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) 
\ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: instance = ErnieBotModel( api_key=credentials['api_key'], diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 841e197873..ffb4a0328c 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -1,32 +1,69 @@ -from typing import Generator, Iterator, List, Optional, Union, cast +from collections.abc import Generator, Iterator +from typing import cast -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageTool, - SystemPromptMessage, ToolPromptMessage, UserPromptMessage) -from core.model_runtime.entities.model_entities import (AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, - ModelType, ParameterRule, ParameterType) -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.xinference.xinference_helper import (XinferenceHelper, - XinferenceModelExtraParameter) -from core.model_runtime.utils import helper -from openai import (APIConnectionError, APITimeoutError, AuthenticationError, ConflictError, InternalServerError, - NotFoundError, OpenAI, PermissionDeniedError, RateLimitError, Stream, UnprocessableEntityError) +from openai import ( + APIConnectionError, + APITimeoutError, + AuthenticationError, + ConflictError, + InternalServerError, + NotFoundError, + OpenAI, + PermissionDeniedError, + RateLimitError, + UnprocessableEntityError, +) from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessageToolCall from openai.types.chat.chat_completion_chunk import ChoiceDeltaFunctionCall, ChoiceDeltaToolCall from openai.types.chat.chat_completion_message import FunctionCall from openai.types.completion import Completion -from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle, - RESTfulChatModelHandle, RESTfulGenerateModelHandle) +from xinference_client.client.restful.restful_client import ( + Client, + RESTfulChatglmCppChatModelHandle, + RESTfulChatModelHandle, + RESTfulGenerateModelHandle, +) + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import 
CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.xinference.xinference_helper import ( + XinferenceHelper, + XinferenceModelExtraParameter, +) +from core.model_runtime.utils import helper class XinferenceAILargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ invoke LLM @@ -60,6 +97,9 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): } """ try: + if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + extra_param = XinferenceHelper.get_xinference_extra_parameter( server_url=credentials['server_url'], model_uid=credentials['model_uid'] @@ -95,7 +135,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool], + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -323,7 +363,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, tools: list[PromptMessageTool] | None = None, - stop: List[str] | None = None, stream: bool = True, user: str | None = None) \ + stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ generate text from LLM @@ -368,7 +408,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): } for tool in tools ] - if isinstance(xinference_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)): + if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( model=credentials['model_uid'], messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index 9ec9e09aa0..1399e9ccd2 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -1,13 +1,20 @@ from typing import Optional +from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + 
InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from xinference_client.client.restful.restful_client import Client, RESTfulRerankModelHandle class XinferenceRerankModel(RerankModel): @@ -85,6 +92,9 @@ class XinferenceRerankModel(RerankModel): :return: """ try: + if "/" in credentials['model_uid'] or "?" in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + self.invoke( model=model, credentials=credentials, diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index bfc77db494..32d2b1516d 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -1,15 +1,22 @@ import time from typing import Optional +from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle + from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.xinference.xinference_helper import XinferenceHelper -from xinference_client.client.restful.restful_client import Client, RESTfulEmbeddingModelHandle, RESTfulModelHandle class XinferenceTextEmbeddingModel(TextEmbeddingModel): @@ -106,6 +113,9 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel): :return: """ try: + if "/" in credentials['model_uid'] or "?" 
in credentials['model_uid'] or "#" in credentials['model_uid']: + raise CredentialsValidateFailedError("model_uid should not contain /, ?, or #") + server_url = credentials['server_url'] model_uid = credentials['model_uid'] extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 764ffe8b65..24a91af62c 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -1,23 +1,21 @@ from os import path from threading import Lock from time import time -from typing import List -from requests import get from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError, MissingSchema, Timeout from requests.sessions import Session -class XinferenceModelExtraParameter(object): +class XinferenceModelExtraParameter: model_format: str model_handle_type: str - model_ability: List[str] + model_ability: list[str] max_tokens: int = 512 context_length: int = 2048 support_function_call: bool = False - def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str], + def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], support_function_call: bool, max_tokens: int, context_length: int) -> None: self.model_format = model_format self.model_handle_type = model_handle_type diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index b961fe8b24..2574234abf 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -1,5 +1,11 @@ -from core.model_runtime.errors.invoke import (InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, - InvokeError, InvokeRateLimitError, InvokeServerUnavailableError) +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) class _CommonZhipuaiAI: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index e3180ec177..c62422dfb0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -1,12 +1,16 @@ -import json -from typing import Any, Dict, Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent, - PromptMessage, PromptMessageContentType, PromptMessageRole, - PromptMessageTool, SystemPromptMessage, - TextPromptMessageContent, ToolPromptMessage, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from 
core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI @@ -20,7 +24,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -86,7 +90,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _generate(self, model: str, credentials_kwargs: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -116,7 +120,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): prompt_messages = prompt_messages[1:] # resolve zhipuai model not support system message and user message, assistant message must be in sequence - new_prompt_messages: List[PromptMessage] = [] + new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: copy_prompt_message = prompt_message.copy() if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]: @@ -272,7 +276,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :return: llm response """ text = '' - assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] + assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for choice in response.choices: if choice.message.tool_calls: for tool_call in choice.message.tool_calls: @@ -332,7 +336,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''): continue - assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = [] + assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = [] for tool_call in delta.delta.tool_calls or []: if tool_call.type == 'function': assistant_tool_calls.append( @@ -406,7 +410,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): return message_text - def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: + def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str: """ :param messages: List of PromptMessage to combine. :return: Combined string with necessary human_prompt and ai_prompt tags. 
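Most hunks in this diff apply the same pyupgrade-style modernization seen above: PEP 585 builtin generics (list[str] in place of typing.List[str]) and PEP 604 unions (X | Y in place of Union[X, Y], or of (X, Y) tuples in isinstance checks). A minimal sketch of the pattern, assuming Python 3.10+; the function names here are illustrative and not taken from the codebase:

from collections.abc import Generator


def chunk(items: list[str], size: int = 2) -> Generator[list[str], None, None]:
    # PEP 585: the builtin list is subscriptable, so typing.List is no longer needed
    for start in range(0, len(items), size):
        yield items[start:start + size]


def describe(value: int | str | None) -> str:
    # PEP 604: int | str | None replaces Optional[Union[int, str]];
    # on 3.10+ the same union syntax is accepted by isinstance.
    if isinstance(value, int | str):
        return f"value is {value!r}"
    return "no value"

Runtime behavior is unchanged; the payoff is fewer typing imports, which is why so many hunks in this PR shrink the typing import lines while moving Generator, Iterator, Mapping, and friends over to collections.abc.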
diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index fd39f5a7a9..0f9fecfc72 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,5 +1,5 @@ import time -from typing import List, Optional, Tuple +from typing import Optional from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -7,7 +7,6 @@ from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI from core.model_runtime.model_providers.zhipuai.zhipuai_sdk._client import ZhipuAI -from langchain.schema.language_model import _get_token_ids_default_method class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): @@ -82,7 +81,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]: + def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to ZhipuAI's embedding endpoint. Args: @@ -102,7 +101,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. 
Args: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 573f0715c4..29b1746351 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -1,7 +1,8 @@ from __future__ import annotations import os -from typing import Mapping, Union +from collections.abc import Mapping +from typing import Union import httpx from httpx import Timeout @@ -37,7 +38,7 @@ class ZhipuAI(HttpClient): if base_url is None: base_url = os.environ.get("ZHIPUAI_BASE_URL") if base_url is None: - base_url = f"https://open.bigmodel.cn/api/paas/v4" + base_url = "https://open.bigmodel.cn/api/paas/v4" from .__version__ import __version__ super().__init__( version=__version__, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index 16c4b54f1a..dab6dac5fe 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from typing_extensions import Literal from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven @@ -15,7 +14,7 @@ if TYPE_CHECKING: class AsyncCompletions(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) @@ -29,8 +28,8 @@ class AsyncCompletions(BaseAPI): top_p: Optional[float] | NotGiven = NOT_GIVEN, max_tokens: int | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, List[str], List[int], List[List[int]], None], - stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index e5bb8cdf68..5c4ed4d1ba 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from typing_extensions import Literal from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven @@ -17,7 +16,7 @@ if TYPE_CHECKING: class Completions(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( @@ -31,8 +30,8 @@ class Completions(BaseAPI): top_p: Optional[float] | NotGiven = NOT_GIVEN, max_tokens: int | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, - messages: 
Union[str, List[str], List[int], object, None], - stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index d5db469de4..35d54592fd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import httpx @@ -14,13 +14,13 @@ if TYPE_CHECKING: class Embeddings(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( self, *, - input: Union[str, List[str], List[int], List[List[int]]], + input: Union[str, list[str], list[int], list[list[int]]], model: Union[str], encoding_format: str | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index dfe52fd54c..5deb8d08f3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import httpx from ..core._base_api import BaseAPI -from ..core._base_type import NOT_GIVEN, Body, FileTypes, Headers, NotGiven, Query +from ..core._base_type import NOT_GIVEN, FileTypes, Headers, NotGiven from ..core._files import is_file_content from ..core._http_client import make_user_request_input from ..types.file_object import FileObject, ListOfFileObject @@ -18,7 +18,7 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index ead6cdae2f..b860de192a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -17,7 +17,7 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index 65ce5b246f..3201426dfa 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, 
Optional import httpx @@ -14,7 +14,7 @@ if TYPE_CHECKING: class Images(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def generations( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py index f3dde8461c..b7cf6bb7fd 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -1,21 +1,22 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from os import PathLike -from typing import IO, TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Type, TypeVar, Union +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union import pydantic -from typing_extensions import Literal, override +from typing_extensions import override Query = Mapping[str, object] Body = object AnyMapping = Mapping[str, object] PrimitiveData = Union[str, int, float, bool, None] -Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] +Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"] ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) _T = TypeVar("_T") if TYPE_CHECKING: - NoneType: Type[None] + NoneType: type[None] else: NoneType = type(None) @@ -74,7 +75,7 @@ Headers = Mapping[str, Union[str, Omit]] ResponseT = TypeVar( "ResponseT", - bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", + bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", ) # for user input files @@ -85,21 +86,21 @@ else: FileTypes = Union[ FileContent, # file content - Tuple[str, FileContent], # (filename, file) - Tuple[str, FileContent, str], # (filename, file , content_type) - Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[str, FileContent], # (filename, file) + tuple[str, FileContent, str], # (filename, file , content_type) + tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) ] -RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] # for httpx client supported files HttpxFileContent = Union[bytes, IO[bytes]] HttpxFileTypes = Union[ FileContent, # file content - Tuple[str, HttpxFileContent], # (filename, file) - Tuple[str, HttpxFileContent, str], # (filename, file , content_type) - Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[str, HttpxFileContent], # (filename, file) + tuple[str, HttpxFileContent, str], # (filename, file , content_type) + tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) ] -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] +HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py index e41ede128a..0796bfe11c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py @@ -2,14 +2,14 @@ from __future__ import annotations import io import os +from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Mapping, Sequence from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles def is_file_content(obj: object) -> bool: - return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike)) + return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) def _transform_file(file: FileTypes) -> HttpxFileTypes: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 5227d20615..e13d2b0233 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- from __future__ import annotations import inspect -from typing import Any, Mapping, Type, Union, cast +from collections.abc import Mapping +from typing import Any, Union, cast import httpx import pydantic @@ -140,7 +140,7 @@ class HttpClient: for k, v in value.items(): items.extend(self._object_to_formfata(f"{key}[{k}]", v)) return items - if isinstance(value, (list, tuple)): + if isinstance(value, list | tuple): for v in value: items.extend(self._object_to_formfata(key + "[]", v)) return items @@ -175,7 +175,7 @@ class HttpClient: def _parse_response( self, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], response: httpx.Response, enable_stream: bool, request_param: ClientRequestParam, @@ -224,7 +224,7 @@ class HttpClient: def request( self, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], params: ClientRequestParam, enable_stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, @@ -259,7 +259,7 @@ class HttpClient: self, path: str, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, enable_stream: bool = False, ) -> ResponseT | StreamResponse: @@ -274,7 +274,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, files: RequestFiles | None = None, enable_stream: bool = False, @@ -294,7 +294,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) @@ -308,7 +308,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: @@ -324,7 +324,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py index bbf2e72e68..b0a91d04a9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py +++ 
b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import time import cachetools.func diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index 3f22731de6..a3f49ba846 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -1,11 +1,10 @@ from __future__ import annotations -from typing import Any, Union, cast +from typing import Any, ClassVar, Union -import pydantic.generics from httpx import Timeout from pydantic import ConfigDict -from typing_extensions import ClassVar, TypedDict, Unpack +from typing_extensions import TypedDict, Unpack from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query from ._utils import remove_notgiven_indict @@ -18,7 +17,7 @@ class UserRequestInput(TypedDict, total=False): params: Query | None -class ClientRequestParam(): +class ClientRequestParam: method: str url: str max_retries: Union[int, NotGiven] = NotGiven() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 116246e645..2f831b6fc9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -1,11 +1,11 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin import httpx import pydantic -from typing_extensions import ParamSpec, get_args, get_origin +from typing_extensions import ParamSpec from ._base_type import NoneType from ._sse_client import StreamResponse @@ -19,7 +19,7 @@ R = TypeVar("R") class HttpResponse(Generic[R]): _cast_type: type[R] - _client: "HttpClient" + _client: HttpClient _parsed: R | None _enable_stream: bool _stream_cls: type[StreamResponse[Any]] @@ -30,7 +30,7 @@ class HttpResponse(Generic[R]): *, raw_response: httpx.Response, cast_type: type[R], - client: "HttpClient", + client: HttpClient, enable_stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 6efe20edcb..66afbfd107 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- from __future__ import annotations import json -from typing import TYPE_CHECKING, Generic, Iterator, Mapping +from collections.abc import Iterator, Mapping +from typing import TYPE_CHECKING, Generic import httpx @@ -36,8 +36,7 @@ class StreamResponse(Generic[ResponseT]): return self._stream_chunks.__next__() def __iter__(self) -> Iterator[ResponseT]: - for item in self._stream_chunks: - yield item + yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: @@ -62,7 +61,7 @@ class StreamResponse(Generic[ResponseT]): pass -class Event(object): +class Event: def __init__( self, event: str | None = None, diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py index 78c97af65b..6b610567da 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Iterable, Mapping, TypeVar +from collections.abc import Iterable, Mapping +from typing import TypeVar from ._base_type import NotGiven diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index bae4197c50..f22f32d251 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -19,5 +19,5 @@ class AsyncCompletion(BaseModel): request_id: Optional[str] = None model: Optional[str] = None task_status: str - choices: List[CompletionChoice] + choices: list[CompletionChoice] usage: CompletionUsage \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index 524e218d39..b2a847c50c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -19,7 +19,7 @@ class CompletionMessageToolCall(BaseModel): class CompletionMessage(BaseModel): content: Optional[str] = None role: str - tool_calls: Optional[List[CompletionMessageToolCall]] = None + tool_calls: Optional[list[CompletionMessageToolCall]] = None class CompletionUsage(BaseModel): @@ -37,7 +37,7 @@ class CompletionChoice(BaseModel): class Completion(BaseModel): model: Optional[str] = None created: Optional[int] = None - choices: List[CompletionChoice] + choices: list[CompletionChoice] request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py index c2e0c57666..c250699741 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -32,7 +32,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): content: Optional[str] = None role: Optional[str] = None - tool_calls: Optional[List[ChoiceDeltaToolCall]] = None + tool_calls: Optional[list[ChoiceDeltaToolCall]] = None class Choice(BaseModel): @@ -49,7 +49,7 @@ class CompletionUsage(BaseModel): class ChatCompletionChunk(BaseModel): id: Optional[str] = None - choices: List[Choice] + choices: list[Choice] created: Optional[int] = None model: Optional[str] = None usage: 
Optional[CompletionUsage] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py index a8737cf8dc..e01f2c815f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -12,11 +12,11 @@ __all__ = ["Embedding", "EmbeddingsResponded"] class Embedding(BaseModel): object: str index: Optional[int] = None - embedding: List[float] + embedding: list[float] class EmbeddingsResponded(BaseModel): object: str - data: List[Embedding] + data: list[Embedding] model: str usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 94db046bd6..917bda7576 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -20,5 +20,5 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): object: Optional[str] = None - data: List[FileObject] + data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index c841e1d756..71c00eaff0 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -1,7 +1,6 @@ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import BaseModel -from typing_extensions import Literal __all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] @@ -35,7 +34,7 @@ class FineTuningJob(BaseModel): object: Optional[str] = None - result_files: List[str] + result_files: list[str] status: str @@ -48,5 +47,5 @@ class FineTuningJob(BaseModel): class ListOfFineTuningJob(BaseModel): object: Optional[str] = None - data: List[FineTuningJob] + data: list[FineTuningJob] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py index 1a70483a7b..e26b448534 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py @@ -1,7 +1,6 @@ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import BaseModel -from typing_extensions import Literal __all__ = ["FineTuningJobEvent", "Metric", "JobEvent"] @@ -32,5 +31,5 @@ class JobEvent(BaseModel): class FineTuningJobEvent(BaseModel): object: Optional[str] = None - data: List[JobEvent] + data: list[JobEvent] has_more: Optional[bool] = None diff --git 
a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py index c661f7cdd5..e1ebc352bc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Union +from typing import Literal, Union -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict __all__ = ["Hyperparameters"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py index 429a7e25bc..b352ce0954 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -15,4 +15,4 @@ class GeneratedImage(BaseModel): class ImagesResponded(BaseModel): created: int - data: List[GeneratedImage] + data: list[GeneratedImage] diff --git a/api/core/model_runtime/utils/_compat.py b/api/core/model_runtime/utils/_compat.py index 305edcac8f..5c34152751 100644 --- a/api/core/model_runtime/utils/_compat.py +++ b/api/core/model_runtime/utils/_compat.py @@ -1,8 +1,7 @@ -from typing import Any +from typing import Any, Literal from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Literal PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index d0d93c34b9..cf6c98e01a 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -1,13 +1,14 @@ import dataclasses import datetime from collections import defaultdict, deque +from collections.abc import Callable from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -46,7 +47,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: return float(dec_value) -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { +ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { bytes: lambda o: o.decode(), Color: str, datetime.date: isoformat, @@ -77,9 +78,9 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable[[Any], Any]] -) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: - encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict( + type_encoder_map: dict[Any, Callable[[Any], Any]] +) -> dict[Callable[[Any], Any], tuple[Any, ...]]: + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( tuple ) for type_, encoder in type_encoder_map.items(): @@ -96,7 +97,7 @@ def jsonable_encoder( exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = 
False, - custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, + custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, ) -> Any: custom_encoder = custom_encoder or {} @@ -109,7 +110,7 @@ def jsonable_encoder( return encoder_instance(obj) if isinstance(obj, BaseModel): # TODO: remove when deprecating Pydantic v1 - encoders: Dict[Any, Any] = {} + encoders: dict[Any, Any] = {} if not PYDANTIC_V2: encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: @@ -149,7 +150,7 @@ def jsonable_encoder( return obj.value if isinstance(obj, PurePath): return str(obj) - if isinstance(obj, (str, int, float, type(None))): + if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): return format(obj, 'f') @@ -184,7 +185,7 @@ def jsonable_encoder( ) encoded_dict[encoded_key] = encoded_value return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] for item in obj: encoded_list.append( @@ -209,7 +210,7 @@ def jsonable_encoder( try: data = dict(obj) except Exception as e: - errors: List[Exception] = [] + errors: list[Exception] = [] errors.append(e) try: data = vars(obj) diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py index 82b2f27234..9cafbf17a3 100644 --- a/api/core/moderation/api/api.py +++ b/api/core/moderation/api/api.py @@ -1,9 +1,10 @@ +from pydantic import BaseModel + from core.extension.api_based_extension_requestor import APIBasedExtensionPoint, APIBasedExtensionRequestor from core.helper.encrypter import decrypt_token from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult from extensions.ext_database import db from models.api_based_extension import APIBasedExtension -from pydantic import BaseModel class ModerationInputParams(BaseModel): diff --git a/api/core/moderation/base.py b/api/core/moderation/base.py index 1cce8f18f2..9a369a9f87 100644 --- a/api/core/moderation/base.py +++ b/api/core/moderation/base.py @@ -2,9 +2,10 @@ from abc import ABC, abstractmethod from enum import Enum from typing import Optional -from core.extension.extensible import Extensible, ExtensionModule from pydantic import BaseModel +from core.extension.extensible import Extensible, ExtensionModule + class ModerationAction(Enum): DIRECT_OUTPUT = 'direct_output' diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/prompt/output_parser/rule_config_generator.py index 61165d628e..619555ce2e 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/prompt/output_parser/rule_config_generator.py @@ -1,7 +1,8 @@ from typing import Any -from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE from langchain.schema import BaseOutputParser, OutputParserException + +from core.prompt.prompts import RULE_CONFIG_GENERATE_TEMPLATE from libs.json_in_md_parser import parse_and_check_json_markdown @@ -18,11 +19,11 @@ class RuleConfigGeneratorOutputParser(BaseOutputParser): raise ValueError("Expected 'prompt' to be a string.") if not isinstance(parsed["variables"], list): raise ValueError( - f"Expected 'variables' to be a list." + "Expected 'variables' to be a list." ) if not isinstance(parsed["opening_statement"], str): raise ValueError( - f"Expected 'opening_statement' to be a str." + "Expected 'opening_statement' to be a str." 
) return parsed except Exception as e: diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py index 49501a2dd7..d8bb0809cf 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -2,9 +2,10 @@ import json import re from typing import Any +from langchain.schema import BaseOutputParser + from core.model_runtime.errors.invoke import InvokeError from core.prompt.prompts import SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT -from langchain.schema import BaseOutputParser class SuggestedQuestionsAfterAnswerOutputParser(BaseOutputParser): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 01cad0c1d4..0a373b7c42 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -2,15 +2,23 @@ import enum import json import os import re -from typing import List, Optional, Tuple, cast +from typing import Optional, cast -from core.entities.application_entities import (AdvancedCompletionPromptTemplateEntity, ModelConfigEntity, - PromptTemplateEntity) +from core.entities.application_entities import ( + AdvancedCompletionPromptTemplateEntity, + ModelConfigEntity, + PromptTemplateEntity, +) from core.file.file_obj import FileObj from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessage, PromptMessageRole, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import ModelPropertyKey from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.prompt_builder import PromptBuilder @@ -59,11 +67,11 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) -> \ - Tuple[List[PromptMessage], Optional[List[str]]]: + tuple[list[PromptMessage], Optional[list[str]]]: app_mode = AppMode.value_of(app_mode) model_mode = ModelMode.value_of(model_config.mode) @@ -107,10 +115,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: app_mode = AppMode.value_of(app_mode) model_mode = ModelMode.value_of(model_config.mode) @@ -174,7 +182,7 @@ class PromptTransform: ) def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int) -> List[PromptMessage]: + max_token_limit: int) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit @@ -209,7 +217,7 @@ class PromptTransform: json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') # Open the JSON file and read its content - with open(json_file_path, 'r', encoding='utf-8') as json_file: + with open(json_file_path, encoding='utf-8') as json_file: return 
json.load(json_file) def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, @@ -217,9 +225,9 @@ class PromptTransform: inputs: dict, query: str, context: Optional[str], - files: List[FileObj], + files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: prompt_messages = [] context_prompt_content = '' @@ -272,8 +280,8 @@ class PromptTransform: query: str, context: Optional[str], memory: Optional[TokenBufferMemory], - files: List[FileObj], - model_config: ModelConfigEntity) -> List[PromptMessage]: + files: list[FileObj], + model_config: ModelConfigEntity) -> list[PromptMessage]: context_prompt_content = '' if context and 'context_prompt' in prompt_rules: prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) @@ -443,10 +451,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix @@ -486,10 +494,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages prompt_messages = [] @@ -527,7 +535,7 @@ class PromptTransform: def _get_completion_app_completion_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - context: Optional[str]) -> List[PromptMessage]: + context: Optional[str]) -> list[PromptMessage]: raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt prompt_messages = [] @@ -546,8 +554,8 @@ class PromptTransform: def _get_completion_app_chat_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - files: List[FileObj], - context: Optional[str]) -> List[PromptMessage]: + files: list[FileObj], + context: Optional[str]) -> list[PromptMessage]: raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages prompt_messages = [] diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 1c505823d1..6e28247d38 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,21 +3,36 @@ from collections import defaultdict from json import JSONDecodeError from typing import Optional +from sqlalchemy.exc import IntegrityError + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations, ProviderModelBundle -from core.entities.provider_entities import (CustomConfiguration, CustomModelConfiguration, CustomProviderConfiguration, - QuotaConfiguration, SystemConfiguration) +from core.entities.provider_entities import ( + CustomConfiguration, + CustomModelConfiguration, + CustomProviderConfiguration, + QuotaConfiguration, + SystemConfiguration, +) from core.helper import encrypter from 
core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import (ConfigurateMethod, CredentialFormSchema, FormType, -                                                            ProviderEntity) +from core.model_runtime.entities.provider_entities import ( +    CredentialFormSchema, +    FormType, +    ProviderEntity, +) from core.model_runtime.model_providers import model_provider_factory from extensions import ext_hosting_provider from extensions.ext_database import db -from models.provider import (Provider, ProviderModel, ProviderQuotaType, ProviderType, TenantDefaultModel, -                             TenantPreferredModelProvider) -from sqlalchemy.exc import IntegrityError +from models.provider import ( +    Provider, +    ProviderModel, +    ProviderQuotaType, +    ProviderType, +    TenantDefaultModel, +    TenantPreferredModelProvider, +) class ProviderManager: diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index 4d2f84b492..a675dfc568 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,15 +1,16 @@ -from typing import List, Optional +from typing import Optional + +from langchain.schema import Document from core.model_manager import ModelInstance -from langchain.schema import Document class RerankRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance -    def run(self, query: str, documents: List[Document], score_threshold: Optional[float] = None, -            top_n: Optional[int] = None, user: Optional[str] = None) -> List[Document]: +    def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, +            top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: """ Run rerank model :param query: search query diff --git a/api/core/spiltter/fixed_text_splitter.py b/api/core/splitter/fixed_text_splitter.py similarity index 87% rename from api/core/spiltter/fixed_text_splitter.py rename to api/core/splitter/fixed_text_splitter.py index a6895998cf..285a7ba14e 100644 --- a/api/core/spiltter/fixed_text_splitter.py +++ b/api/core/splitter/fixed_text_splitter.py @@ -1,13 +1,22 @@ """Functionality for splitting text.""" from __future__ import annotations -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast + +from langchain.text_splitter import ( +    TS, +    AbstractSet, +    Collection, +    Literal, +    RecursiveCharacterTextSplitter, +    TokenTextSplitter, +    Type, +    Union, +) from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer -from langchain.text_splitter import (TS, AbstractSet, Collection, Literal, RecursiveCharacterTextSplitter, -                                      TokenTextSplitter, Type, Union) class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): @@ -19,8 +28,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder( cls: Type[TS], embedding_model_instance: Optional[ModelInstance], -            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), -            disallowed_special: Union[Literal["all"], Collection[str]] = "all", +            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), +            disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ): def _token_encoder(text: str) -> int: @@ -50,13 +59,13 @@ class 
EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): - def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): + def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] - def split_text(self, text: str) -> List[str]: + def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) @@ -72,7 +81,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) return final_chunks - def recursive_split_text(self, text: str) -> List[str]: + def recursive_split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/third_party/langchain/llms/fake.py index 64117477e1..ab5152b38d 100644 --- a/api/core/third_party/langchain/llms/fake.py +++ b/api/core/third_party/langchain/llms/fake.py @@ -1,5 +1,6 @@ import time -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel @@ -19,8 +20,8 @@ class FakeLLM(SimpleChatModel): def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -36,8 +37,8 @@ class FakeLLM(SimpleChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: diff --git a/api/core/third_party/spark/spark_llm.py b/api/core/third_party/spark/spark_llm.py index ff7f04c396..5c97bba530 100644 --- a/api/core/third_party/spark/spark_llm.py +++ b/api/core/third_party/spark/spark_llm.py @@ -1,5 +1,4 @@ import base64 -import datetime import hashlib import hmac import json diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py index 3bb2bb5eaa..208490a5bf 100644 --- a/api/core/tool/current_datetime_tool.py +++ b/api/core/tool/current_datetime_tool.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Type from langchain.tools import BaseTool from pydantic import BaseModel, Field @@ -12,7 +11,7 @@ class DatetimeToolInput(BaseModel): class DatetimeTool(BaseTool): """Tool for querying current datetime.""" name: str = "current_datetime" - args_schema: Type[BaseModel] = DatetimeToolInput + args_schema: type[BaseModel] = DatetimeToolInput description: str = "A tool when you want to get the current date, time, week, month or year, " \ "and the time zone is UTC. Result is \"