diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index ab585a5ae9..e7a8c98d26 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,5 +1,5 @@
-FROM mcr.microsoft.com/devcontainers/python:3.10
+FROM mcr.microsoft.com/devcontainers/python:3.12
 
 # [Optional] Uncomment this section to install additional OS packages.
 # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
-# && apt-get -y install --no-install-recommends
\ No newline at end of file
+# && apt-get -y install --no-install-recommends
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index ebc8bf74c1..339ad60ce0 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -1,7 +1,7 @@
 // For format details, see https://aka.ms/devcontainer.json. For config options, see the
 // README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
 {
-    "name": "Python 3.10",
+    "name": "Python 3.12",
     "build": {
         "context": "..",
         "dockerfile": "Dockerfile"
diff --git a/.github/actions/setup-poetry/action.yml b/.github/actions/setup-poetry/action.yml
index 5feab33d1d..2e76676f37 100644
--- a/.github/actions/setup-poetry/action.yml
+++ b/.github/actions/setup-poetry/action.yml
@@ -4,7 +4,7 @@ inputs:
   python-version:
     description: Python version to use and the Poetry installed with
     required: true
-    default: '3.10'
+    default: '3.11'
   poetry-version:
     description: Poetry version to set up
     required: true
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 76e844aaad..e1c0bf33a4 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -20,7 +20,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
           - "3.11"
           - "3.12"
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index caddd23bab..73af370063 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -8,6 +8,8 @@ on:
       - api/core/rag/datasource/**
       - docker/**
       - .github/workflows/vdb-tests.yml
+      - api/poetry.lock
+      - api/pyproject.toml
 
 concurrency:
   group: vdb-tests-${{ github.head_ref || github.run_id }}
@@ -20,7 +22,6 @@ jobs:
     strategy:
       matrix:
         python-version:
-          - "3.10"
           - "3.11"
           - "3.12"
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index da2928d189..22261804fc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,6 +1,8 @@
+# CONTRIBUTING
+
 So you're looking to contribute to Dify - that's awesome, we can't wait to see what you do. As a startup with limited headcount and funding, we have grand ambitions to design the most intuitive workflow for building and managing LLM applications. Any help from the community counts, truly.
 
-We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part.
+We need to be nimble and ship fast given where we are, but we also want to make sure that contributors like you get as smooth an experience at contributing as possible. We've assembled this contribution guide for that purpose, aiming at getting you familiarized with the codebase & how we work with contributors, so you could quickly jump to the fun part. This guide, like Dify itself, is a constant work in progress.
 We highly appreciate your understanding if at times it lags behind the actual project, and welcome any feedback for us to improve.
 
@@ -10,14 +12,12 @@ In terms of licensing, please take a minute to read our short [License and Contr
 
 [Find](https://github.com/langgenius/dify/issues?q=is:issue+is:open) an existing issue, or [open](https://github.com/langgenius/dify/issues/new/choose) a new one. We categorize issues into 2 types:
 
-### Feature requests:
+### Feature requests
 
 * If you're opening a new feature request, we'd like you to explain what the proposed feature achieves, and include as much context as possible. [@perzeusss](https://github.com/perzeuss) has made a solid [Feature Request Copilot](https://udify.app/chat/MK2kVSnw1gakVwMX) that helps you draft out your needs. Feel free to give it a try.
 
 * If you want to pick one up from the existing issues, simply drop a comment below it saying so.
-
-
   A team member working in the related direction will be looped in. If all looks good, they will give the go-ahead for you to start coding. We ask that you hold off working on the feature until then, so none of your work goes to waste should we propose changes.
 
   Depending on whichever area the proposed feature falls under, you might talk to different team members. Here's rundown of the areas each our team members are working on at the moment:
 
@@ -40,7 +40,7 @@ In terms of licensing, please take a minute to read our short [License and Contr
   | Non-core features and minor enhancements | Low Priority |
   | Valuable but not immediate | Future-Feature |
 
-### Anything else (e.g. bug report, performance optimization, typo correction):
+### Anything else (e.g. bug report, performance optimization, typo correction)
 
 * Start coding right away.
 
@@ -52,7 +52,6 @@ In terms of licensing, please take a minute to read our short [License and Contr
   | Non-critical bugs, performance boosts | Medium Priority |
   | Minor fixes (typos, confusing but working UI) | Low Priority |
 
-
 ## Installing
 
 Here are the steps to set up Dify for development:
 
@@ -63,7 +62,7 @@ Here are the steps to set up Dify for development:
 
 Clone the forked repository from your terminal:
 
-```
+```shell
 git clone git@github.com:<github_username>/dify.git
 ```
 
@@ -71,11 +70,11 @@ git clone git@github.com:<github_username>/dify.git
 
 Dify requires the following dependencies to build, make sure they're installed on your system:
 
-- [Docker](https://www.docker.com/)
-- [Docker Compose](https://docs.docker.com/compose/install/)
-- [Node.js v18.x (LTS)](http://nodejs.org)
-- [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+* [Docker](https://www.docker.com/)
+* [Docker Compose](https://docs.docker.com/compose/install/)
+* [Node.js v18.x (LTS)](http://nodejs.org)
+* [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
+* [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. Installations
 
@@ -85,7 +84,7 @@ Check the [installation FAQ](https://docs.dify.ai/learn-more/faq/install-faq) fo
 
 ### 5. Visit dify in your browser
 
-To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
+To validate your set up, head over to [http://localhost:3000](http://localhost:3000) (the default, or your self-configured URL and port) in your browser. You should now see Dify up and running.
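If the web console loads but you also want to confirm the API server on its own, a quick probe helps (a sketch: it assumes the API serves on its default port 5001 and that the console blueprint's ping route is enabled, as hinted by the `ping` controller imported later in this diff):

```shell
# Expect a small JSON "pong"-style response if the Flask API is up
curl http://localhost:5001/console/api/ping
```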
 ## Developing
 
@@ -97,9 +96,9 @@ To help you quickly navigate where your contribution fits, a brief, annotated ou
 
 ### Backend
 
-Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
+Dify’s backend is written in Python using [Flask](https://flask.palletsprojects.com/en/3.0.x/). It uses [SQLAlchemy](https://www.sqlalchemy.org/) for ORM and [Celery](https://docs.celeryq.dev/en/stable/getting-started/introduction.html) for task queueing. Authorization logic goes via Flask-login.
 
-```
+```text
 [api/]
 ├── constants             // Constant settings used throughout code base.
 ├── controllers           // API route definitions and request handling logic.
 
@@ -121,7 +120,7 @@ Dify’s backend is written in Python using [Flask](https://flask.palletsproject
 
 The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typescript and uses [Tailwind CSS](https://tailwindcss.com/) for styling. [React-i18next](https://react.i18next.com/) is used for internationalization.
 
-```
+```text
 [web/]
 ├── app                   // layouts, pages, and components
 │   ├── (commonLayout)    // common layout used throughout the app
 
@@ -149,10 +148,10 @@ The website is bootstrapped on [Next.js](https://nextjs.org/) boilerplate in Typ
 
 ## Submitting your PR
 
-At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
+At last, time to open a pull request (PR) to our repo. For major features, we first merge them into the `deploy/dev` branch for testing, before they go into the `main` branch. If you run into issues like merge conflicts or don't know how to open a pull request, check out [GitHub's pull request tutorial](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests).
 
 And that's it! Once your PR is merged, you will be featured as a contributor in our [README](https://github.com/langgenius/dify/blob/main/README.md).
 
 ## Getting Help
 
-If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
+If you ever get stuck or got a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat.
diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md
index 310c55090a..4fa43b24ee 100644
--- a/CONTRIBUTING_CN.md
+++ b/CONTRIBUTING_CN.md
@@ -71,7 +71,7 @@ Dify 依赖以下工具和库:
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. 安装
diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md
index a68bdeddbc..22e30e9c03 100644
--- a/CONTRIBUTING_JA.md
+++ b/CONTRIBUTING_JA.md
@@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) version 3.10.x
+- [Python](https://www.python.org/) version 3.11.x or 3.12.x
 
 ### 4. インストール
diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md
index a77239ff38..ad41f51aeb 100644
--- a/CONTRIBUTING_VI.md
+++ b/CONTRIBUTING_VI.md
@@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ
 - [Docker Compose](https://docs.docker.com/compose/install/)
 - [Node.js v18.x (LTS)](http://nodejs.org)
 - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/)
-- [Python](https://www.python.org/) phiên bản 3.10.x
+- [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x
 
 ### 4. Cài đặt
 
@@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ
 
 ## Nhận trợ giúp
 
-Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
\ No newline at end of file
+Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng.
diff --git a/api/.env.example b/api/.env.example
index 1a242a3daa..f8a2812563 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -42,6 +42,11 @@ REDIS_SENTINEL_USERNAME=
 REDIS_SENTINEL_PASSWORD=
 REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
 
+# redis Cluster configuration.
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=
+
 # PostgreSQL database configuration
 DB_USERNAME=postgres
 DB_PASSWORD=difyai123456
diff --git a/api/Dockerfile b/api/Dockerfile
index 175535b188..e7b64f1107 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -1,5 +1,5 @@
 # base image
-FROM python:3.10-slim-bookworm AS base
+FROM python:3.12-slim-bookworm AS base
 
 WORKDIR /app/api
diff --git a/api/README.md b/api/README.md
index de2baee4c5..461dac4759 100644
--- a/api/README.md
+++ b/api/README.md
@@ -18,12 +18,17 @@
    ```
 
 2. Copy `.env.example` to `.env`
+
+   ```cli
+   cp .env.example .env
+   ```
 
 3. Generate a `SECRET_KEY` in the `.env` file.
 
+   bash for Linux
    ```bash for Linux
    sed -i "/^SECRET_KEY=/c\SECRET_KEY=$(openssl rand -base64 42)" .env
    ```
-
+   bash for Mac
    ```bash for Mac
    secret_key=$(openssl rand -base64 42)
    sed -i '' "/^SECRET_KEY=/c\\
    SECRET_KEY=${secret_key}" .env
    ```
 
@@ -37,18 +42,10 @@
 
 5. Install dependencies
 
    ```bash
-   poetry env use 3.10
+   poetry env use 3.12
    poetry install
    ```
 
-   In case of contributors missing to update dependencies for `pyproject.toml`, you can perform the following shell instead.
-
-   ```bash
-   poetry shell # activate current environment
-   poetry add $(cat requirements.txt) # install dependencies of production and update pyproject.toml
-   poetry add $(cat requirements-dev.txt) --group dev # install dependencies of development and update pyproject.toml
-   ```
-
 6. Run migrate
 
    Before the first launch, migrate the database to the latest version.
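   The migration itself typically goes through the Flask CLI via Poetry (a sketch, assuming the project's standard Flask-Migrate setup):

   ```bash
   poetry run python -m flask db upgrade
   ```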
@@ -84,5 +81,3 @@
    ```bash
    poetry run -C api bash dev/pytest/pytest_all_tests.sh
    ```
-
-
diff --git a/api/app.py b/api/app.py
index a667a84fd6..c1acb8bd0d 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,6 +1,11 @@
 import os
 import sys
 
+python_version = sys.version_info
+if not ((3, 11) <= python_version < (3, 13)):
+    print(f"Python 3.11 or 3.12 is required, current version is {python_version.major}.{python_version.minor}")
+    raise SystemExit(1)
+
 from configs import dify_config
 
 if not dify_config.DEBUG:
@@ -30,9 +35,6 @@ from models import account, dataset, model, source, task, tool, tools, web  # no
 
 # DO NOT REMOVE ABOVE
 
-if sys.version_info[:2] == (3, 10):
-    print("Warning: Python 3.10 will not be supported in the next version.")
-
 warnings.simplefilter("ignore", ResourceWarning)
diff --git a/api/configs/app_config.py b/api/configs/app_config.py
index 61de73c868..07ef6121cc 100644
--- a/api/configs/app_config.py
+++ b/api/configs/app_config.py
@@ -27,7 +27,6 @@ class DifyConfig(
         # read from dotenv format config file
         env_file=".env",
         env_file_encoding="utf-8",
-        frozen=True,
         # ignore extra attributes
         extra="ignore",
     )
diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py
index 26b9b1347c..2e98c31ec3 100644
--- a/api/configs/middleware/cache/redis_config.py
+++ b/api/configs/middleware/cache/redis_config.py
@@ -68,3 +68,18 @@ class RedisConfig(BaseSettings):
         description="Socket timeout in seconds for Redis Sentinel connections",
         default=0.1,
     )
+
+    REDIS_USE_CLUSTERS: bool = Field(
+        description="Enable Redis Clusters mode for high availability",
+        default=False,
+    )
+
+    REDIS_CLUSTERS: Optional[str] = Field(
+        description="Comma-separated list of Redis Clusters nodes (host:port)",
+        default=None,
+    )
+
+    REDIS_CLUSTERS_PASSWORD: Optional[str] = Field(
+        description="Password for Redis Clusters authentication (if required)",
+        default=None,
+    )
diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py
index 1f2b8224e8..7e95e79bfb 100644
--- a/api/configs/packaging/__init__.py
+++ b/api/configs/packaging/__init__.py
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
 
     CURRENT_VERSION: str = Field(
         description="Dify version",
-        default="0.11.2",
+        default="0.12.0",
     )
 
     COMMIT_SHA: str = Field(
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index 8a5c2e5b8f..f46d5b6b13 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -2,6 +2,7 @@ from flask import Blueprint
 
 from libs.external_api import ExternalApi
 
+from .app.app_import import AppImportApi, AppImportConfirmApi
 from .files import FileApi, FilePreviewApi, FileSupportTypeApi
 from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
 
@@ -17,6 +18,10 @@ api.add_resource(FileSupportTypeApi, "/files/support-type")
 api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
 api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
 
+# Import App
+api.add_resource(AppImportApi, "/apps/imports")
+api.add_resource(AppImportConfirmApi, "/apps/imports/<string:import_id>/confirm")
+
 # Import other controllers
 from . import admin, apikey, extension, feature, ping, setup, version
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 5a4cd7684f..da72b704c7 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,7 +1,10 @@
 import uuid
+from typing import cast
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import BadRequest, Forbidden, abort
 
 from controllers.console import api
@@ -13,13 +16,15 @@ from controllers.console.wraps import (
     setup_required,
 )
 from core.ops.ops_trace_manager import OpsTraceManager
+from extensions.ext_database import db
 from fields.app_fields import (
     app_detail_fields,
     app_detail_fields_with_site,
     app_pagination_fields,
 )
 from libs.login import login_required
-from services.app_dsl_service import AppDslService
+from models import Account, App
+from services.app_dsl_service import AppDslService, ImportMode
 from services.app_service import AppService
 
 ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
@@ -92,61 +97,6 @@ class AppListApi(Resource):
         return app, 201
 
 
-class AppImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon_type", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
-class AppImportFromUrlApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @marshal_with(app_detail_fields_with_site)
-    @cloud_edition_billing_resource_check("apps")
-    def post(self):
-        """Import app from url"""
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("url", type=str, required=True, nullable=False, location="json")
-        parser.add_argument("name", type=str, location="json")
-        parser.add_argument("description", type=str, location="json")
-        parser.add_argument("icon", type=str, location="json")
-        parser.add_argument("icon_background", type=str, location="json")
-        args = parser.parse_args()
-
-        app = AppDslService.import_and_create_new_app_from_url(
-            tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user
-        )
-
-        return app, 201
-
-
 class AppApi(Resource):
     @setup_required
     @login_required
@@ -224,10 +174,24 @@ class AppCopyApi(Resource):
         parser.add_argument("icon_background", type=str, location="json")
         args = parser.parse_args()
 
-        data = AppDslService.export_dsl(app_model=app_model, include_secret=True)
-        app = AppDslService.import_and_create_new_app(
-            tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user
-        )
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            yaml_content = import_service.export_dsl(app_model=app_model, include_secret=True)
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=ImportMode.YAML_CONTENT.value,
+                yaml_content=yaml_content,
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+            )
+            session.commit()
+
+            stmt = select(App).where(App.id == result.app_id)
+            app = session.scalar(stmt)
 
         return app, 201
 
@@ -368,8 +332,6 @@ class AppTraceApi(Resource):
 
 
 api.add_resource(AppListApi, "/apps")
-api.add_resource(AppImportApi, "/apps/import")
-api.add_resource(AppImportFromUrlApi, "/apps/import/url")
 api.add_resource(AppApi, "/apps/<uuid:app_id>")
 api.add_resource(AppCopyApi, "/apps/<uuid:app_id>/copy")
 api.add_resource(AppExportApi, "/apps/<uuid:app_id>/export")
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
new file mode 100644
index 0000000000..244dcd75de
--- /dev/null
+++ b/api/controllers/console/app/app_import.py
@@ -0,0 +1,90 @@
+from typing import cast
+
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.console.wraps import (
+    account_initialization_required,
+    setup_required,
+)
+from extensions.ext_database import db
+from fields.app_fields import app_import_fields
+from libs.login import login_required
+from models import Account
+from services.app_dsl_service import AppDslService, ImportStatus
+
+
+class AppImportApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_import_fields)
+    def post(self):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        parser = reqparse.RequestParser()
+        parser.add_argument("mode", type=str, required=True, location="json")
+        parser.add_argument("yaml_content", type=str, location="json")
+        parser.add_argument("yaml_url", type=str, location="json")
+        parser.add_argument("name", type=str, location="json")
+        parser.add_argument("description", type=str, location="json")
+        parser.add_argument("icon_type", type=str, location="json")
+        parser.add_argument("icon", type=str, location="json")
+        parser.add_argument("icon_background", type=str, location="json")
+        parser.add_argument("app_id", type=str, location="json")
+        args = parser.parse_args()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            # Import app
+            account = cast(Account, current_user)
+            result = import_service.import_app(
+                account=account,
+                import_mode=args["mode"],
+                yaml_content=args.get("yaml_content"),
+                yaml_url=args.get("yaml_url"),
+                name=args.get("name"),
+                description=args.get("description"),
+                icon_type=args.get("icon_type"),
+                icon=args.get("icon"),
+                icon_background=args.get("icon_background"),
+                app_id=args.get("app_id"),
+            )
+            session.commit()
+
+        # Return appropriate status code based on result
+        status = result.status
+        if status == ImportStatus.FAILED.value:
+            return result.model_dump(mode="json"), 400
+        elif status == ImportStatus.PENDING.value:
+            return result.model_dump(mode="json"), 202
+        return result.model_dump(mode="json"), 200
+
+
+class AppImportConfirmApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(app_import_fields)
+    def post(self, import_id):
+        # Check user role first
+        if not current_user.is_editor:
+            raise Forbidden()
+
+        # Create service with session
+        with Session(db.engine) as session:
+            import_service = AppDslService(session)
+            # Confirm import
+            account = cast(Account, current_user)
+            result = import_service.confirm_import(import_id=import_id, account=account)
+            session.commit()
+
+        # Return appropriate status code based on result
+        if result.status == ImportStatus.FAILED.value:
+            return result.model_dump(mode="json"), 400
+        return result.model_dump(mode="json"), 200
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 7b78f622b9..a25004be4d 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pytz
 from flask_login import current_user
@@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id):
         raise NotFound("Conversation Not Exists.")
 
     if not conversation.read_at:
-        conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
         conversation.read_account_id = current_user.id
         db.session.commit()
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 2f5645852f..407f689819 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
@@ -75,7 +75,7 @@ class AppSite(Resource):
                 setattr(site, attr_name, value)
 
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
@@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource):
 
         site.code = Site.generate_code(16)
         site.updated_by = current_user.id
-        site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        site.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return site
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index f7027fb226..cc05a0d509 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -20,7 +20,6 @@ from libs.helper import TimestampField, uuid_value
 from libs.login import current_user, login_required
 from models import App
 from models.model import AppMode
-from services.app_dsl_service import AppDslService
 from services.app_generate_service import AppGenerateService
 from services.errors.app import WorkflowHashNotEqualError
 from services.workflow_service import WorkflowService
@@ -126,31 +125,6 @@ class DraftWorkflowApi(Resource):
         }
 
 
-class DraftWorkflowImportApi(Resource):
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
-    @marshal_with(workflow_fields)
-    def post(self, app_model: App):
-        """
-        Import draft workflow
-        """
-        # The role of the current user in the ta table must be admin, owner, or editor
-        if not current_user.is_editor:
-            raise Forbidden()
-
-        parser = reqparse.RequestParser()
-        parser.add_argument("data", type=str, required=True, nullable=False, location="json")
-        args = parser.parse_args()
-
-        workflow = AppDslService.import_and_overwrite_workflow(
-            app_model=app_model, data=args["data"], account=current_user
-        )
-
-        return workflow
-
-
 class AdvancedChatDraftWorkflowRunApi(Resource):
     @setup_required
     @login_required
@@ -453,7 +427,6 @@ class ConvertToWorkflowApi(Resource):
 
 
 api.add_resource(DraftWorkflowApi, "/apps/<uuid:app_id>/workflows/draft")
-api.add_resource(DraftWorkflowImportApi, "/apps/<uuid:app_id>/workflows/draft/import")
 api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps/<uuid:app_id>/advanced-chat/workflows/draft/run")
 api.add_resource(DraftWorkflowRunApi, "/apps/<uuid:app_id>/workflows/draft/run")
 api.add_resource(WorkflowTaskStopApi, "/apps/<uuid:app_id>/workflow-runs/tasks/<string:task_id>/stop")
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index be353cefac..d2aa7c903b 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -65,7 +65,7 @@ class ActivateApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = AccountStatus.ACTIVE.value
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index d27e3353c9..f53c28e2ec 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional
 
 import requests
@@ -106,7 +106,7 @@ class OAuthCallback(Resource):
 
         if account.status == AccountStatus.PENDING.value:
             account.status = AccountStatus.ACTIVE.value
-            account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         try:
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index ef1e87905a..278295ca39 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -83,7 +83,7 @@ class DataSourceApi(Resource):
         if action == "enable":
             if data_source_binding.disabled:
                 data_source_binding.disabled = False
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
@@ -92,7 +92,7 @@ class DataSourceApi(Resource):
         if action == "disable":
             if not data_source_binding.disabled:
                 data_source_binding.disabled = True
-                data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(data_source_binding)
                 db.session.commit()
             else:
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index f38408525a..f20261abc2 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,6 +1,6 @@
 import logging
 from argparse import ArgumentTypeError
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask import request
 from flask_login import current_user
@@ -665,7 +665,7 @@ class DocumentProcessingApi(DocumentResource):
             raise InvalidActionError("Document not in indexing state.")
 
         document.paused_by = current_user.id
-        document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.paused_at = datetime.now(UTC).replace(tzinfo=None)
         document.is_paused = True
         db.session.commit()
 
@@ -745,7 +745,7 @@ class DocumentMetadataApi(DocumentResource):
                 document.doc_metadata[key] = value
 
         document.doc_type = doc_type
-        document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        document.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success", "message": "Document metadata updated."}, 200
@@ -787,7 +787,7 @@ class DocumentStatusApi(DocumentResource):
             document.enabled = True
             document.disabled_at = None
             document.disabled_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -804,9 +804,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already disabled.")
 
             document.enabled = False
-            document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
             document.disabled_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
@@ -821,9 +821,9 @@ class DocumentStatusApi(DocumentResource):
                 raise InvalidActionError("Document already archived.")
 
             document.archived = True
-            document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.archived_at = datetime.now(UTC).replace(tzinfo=None)
             document.archived_by = current_user.id
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             if document.enabled:
@@ -840,7 +840,7 @@ class DocumentStatusApi(DocumentResource):
             document.archived = False
             document.archived_at = None
             document.archived_by = None
-            document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
             # Set cache to prevent indexing the same document multiple times
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 5d8d664e41..6f7ef86d2c 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -1,5 +1,5 @@
 import uuid
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 import pandas as pd
 from flask import request
@@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource):
             raise InvalidActionError("Segment is already disabled.")
 
         segment.enabled = False
-        segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
         segment.disabled_by = current_user.id
         db.session.commit()
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 125bc1af8c..85c43f8101 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import reqparse
@@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource):
         streaming = args["response_mode"] == "streaming"
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
@@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource):
 
         args["auto_generate_name"] = False
 
-        installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         try:
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index d72715a38c..b60c4e176b 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,4 +1,4 @@
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 
 from flask_login import current_user
 from flask_restful import Resource, inputs, marshal_with, reqparse
@@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource):
                 tenant_id=current_tenant_id,
                 app_owner_tenant_id=app.tenant_id,
                 is_pinned=False,
-                last_used_at=datetime.now(timezone.utc).replace(tzinfo=None),
+                last_used_at=datetime.now(UTC).replace(tzinfo=None),
             )
             db.session.add(new_installed_app)
             db.session.commit()
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index 750f65168f..f704783cff 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -60,7 +60,7 @@ class AccountInitApi(Resource):
                 raise InvalidInvitationCodeError()
 
             invitation_code.status = "used"
-            invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             invitation_code.used_by_tenant_id = account.current_tenant_id
             invitation_code.used_by_account_id = account.id
 
@@ -68,7 +68,7 @@ class AccountInitApi(Resource):
         account.timezone = args["timezone"]
         account.interface_theme = "light"
         account.status = "active"
-        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         return {"result": "success"}
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index b935b23ed6..2128c4c53f 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from enum import Enum
 from functools import wraps
 from typing import Optional
@@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None):
     if not api_token:
         raise Unauthorized("Access token is invalid")
 
-    api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None)
     db.session.commit()
 
     return api_token
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 860ec5de0c..ead293200e 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -2,7 +2,7 @@ import json
 import logging
 import uuid
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union, cast
 
 from core.agent.entities import AgentEntity, AgentToolEntity
@@ -114,16 +114,9 @@ class BaseAgentRunner(AppRunner):
         # check if model supports stream tool call
         llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
         model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
-        if model_schema and ModelFeature.STREAM_TOOL_CALL in (model_schema.features or []):
-            self.stream_tool_call = True
-        else:
-            self.stream_tool_call = False
-
-        # check if model supports vision
-        if model_schema and ModelFeature.VISION in (model_schema.features or []):
-            self.files = application_generate_entity.files
-        else:
-            self.files = []
+        features = model_schema.features if model_schema and model_schema.features else []
+        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
+        self.files = application_generate_entity.files if ModelFeature.VISION in features else []
         self.query = None
         self._current_thoughts: list[PromptMessage] = []
 
@@ -250,7 +243,7 @@ class BaseAgentRunner(AppRunner):
         update prompt message tool
         """
         # try to get tool runtime parameters
-        tool_runtime_parameters = tool.get_runtime_parameters() or []
+        tool_runtime_parameters = tool.get_runtime_parameters()
 
         for parameter in tool_runtime_parameters:
             if parameter.form != ToolParameter.ToolParameterForm.LLM:
@@ -419,7 +412,7 @@ class BaseAgentRunner(AppRunner):
             .first()
         )
 
-        db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
         db.session.commit()
         db.session.close()
diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index a22395b8e3..b9aae7904f 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -1,3 +1,4 @@
+import uuid
 from typing import Optional
 
 from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity
diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
index a91b9f0f02..cdc82860c6 100644
--- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
+++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py
@@ -11,7 +11,7 @@ from core.provider_manager import ProviderManager
 
 class ModelConfigConverter:
     @classmethod
-    def convert(cls, app_config: EasyUIBasedAppConfig, skip_check: bool = False) -> ModelConfigWithCredentialsEntity:
+    def convert(cls, app_config: EasyUIBasedAppConfig) -> ModelConfigWithCredentialsEntity:
         """
         Convert app model config dict to entity.
         :param app_config: app config
@@ -38,27 +38,23 @@
         )
 
         if model_credentials is None:
-            if not skip_check:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            else:
-                model_credentials = {}
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
 
-        if not skip_check:
-            # check model
-            provider_model = provider_model_bundle.configuration.get_provider_model(
-                model=model_config.model, model_type=ModelType.LLM
-            )
+        # check model
+        provider_model = provider_model_bundle.configuration.get_provider_model(
+            model=model_config.model, model_type=ModelType.LLM
+        )
 
-            if provider_model is None:
-                model_name = model_config.model
-                raise ValueError(f"Model {model_name} not exist.")
+        if provider_model is None:
+            model_name = model_config.model
+            raise ValueError(f"Model {model_name} not exist.")
 
-            if provider_model.status == ModelStatus.NO_CONFIGURE:
-                raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
-            elif provider_model.status == ModelStatus.NO_PERMISSION:
-                raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
-            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
-                raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
+        if provider_model.status == ModelStatus.NO_CONFIGURE:
+            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
+        elif provider_model.status == ModelStatus.NO_PERMISSION:
+            raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.")
+        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
+            raise QuotaExceededError(f"Model provider {provider_name} quota exceeded.")
 
         # model config
         completion_params = model_config.parameters
@@ -76,7 +72,7 @@ class ModelConfigConverter:
 
         model_schema = model_type_instance.get_model_schema(model_config.model, model_credentials)
 
-        if not skip_check and not model_schema:
+        if not model_schema:
             raise ValueError(f"Model {model_name} not exist.")
 
         return ModelConfigWithCredentialsEntity(
diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
index 82a0e56ce8..fa30511f63 100644
--- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py
@@ -1,4 +1,5 @@
 from core.app.app_config.entities import (
+    AdvancedChatMessageEntity,
     AdvancedChatPromptTemplateEntity,
     AdvancedCompletionPromptTemplateEntity,
     PromptTemplateEntity,
@@ -25,7 +26,9 @@ class PromptTemplateConfigManager:
             chat_prompt_messages = []
             for message in chat_prompt_config.get("prompt", []):
                 chat_prompt_messages.append(
-                    {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    AdvancedChatMessageEntity(
+                        **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])}
+                    )
                 )
 
             advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages)
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 9b72452d7a..15bd353484 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -1,5 +1,5 @@
 from collections.abc import Sequence
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, Field, field_validator
@@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel):
     advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None
 
 
-class VariableEntityType(str, Enum):
+class VariableEntityType(StrEnum):
     TEXT_INPUT = "text-input"
     SELECT = "select"
     PARAGRAPH = "paragraph"
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 00e5a74732..ffe56ce410 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -127,7 +127,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index d1564a260e..48ee590e2f 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -134,7 +134,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index 2c78d95778..85b7aced55 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Optional
 
 from core.app.app_config.entities import VariableEntityType
@@ -6,7 +6,7 @@ from core.file import File, FileUploadConfig
 from factories import file_factory
 
 if TYPE_CHECKING:
-    from core.app.app_config.entities import AppConfig, VariableEntity
+    from core.app.app_config.entities import VariableEntity
 
 
 class BaseAppGenerator:
@@ -14,23 +14,23 @@ class BaseAppGenerator:
         self,
         *,
         user_inputs: Optional[Mapping[str, Any]],
-        app_config: "AppConfig",
+        variables: Sequence["VariableEntity"],
+        tenant_id: str,
     ) -> Mapping[str, Any]:
         user_inputs = user_inputs or {}
         # Filter input variables from form configuration, handle required fields, default values, and option values
-        variables = app_config.variables
         user_inputs = {
             var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var)
             for var in variables
         }
         user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()}
         # Convert files in inputs to File
-        entity_dictionary = {item.variable: item for item in app_config.variables}
+        entity_dictionary = {item.variable: item for item in variables}
         # Convert single file to File
         files_inputs = {
             k: file_factory.build_from_mapping(
                 mapping=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
@@ -44,7 +44,7 @@ class BaseAppGenerator:
         file_list_inputs = {
             k: file_factory.build_from_mappings(
                 mappings=v,
-                tenant_id=app_config.tenant_id,
+                tenant_id=tenant_id,
                 config=FileUploadConfig(
                     allowed_file_types=entity_dictionary[k].allowed_file_types,
                     allowed_file_extensions=entity_dictionary[k].allowed_file_extensions,
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index e683dfef3f..5b3efe12eb 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -132,7 +132,9 @@ class ChatAppGenerator(MessageBasedAppGenerator):
             conversation_id=conversation.id if conversation else None,
             inputs=conversation.inputs
             if conversation
-            else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            else self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 22ee8b0967..e9e50015bd 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -113,7 +113,9 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
             app_config=app_config,
             model_conf=ModelConfigConverter.convert(app_config),
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             query=query,
             files=file_objs,
             user_id=user.id,
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index da206f01e7..95ae798ec1 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from collections.abc import Generator
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional, Union
 
 from sqlalchemy import and_
@@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
             db.session.commit()
             db.session.refresh(conversation)
         else:
-            conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         message = Message(
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 65da39b220..31efe43412 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -96,7 +96,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
             task_id=str(uuid.uuid4()),
             app_config=app_config,
             file_upload_config=file_extra_config,
-            inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config),
+            inputs=self._prepare_user_inputs(
+                user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id
+            ),
             files=system_files,
             user_id=user.id,
             stream=stream,
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 2872390d46..1cf72ae79e 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -43,7 +43,6 @@ from core.workflow.graph_engine.entities.event import (
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.nodes import NodeType
-from core.workflow.nodes.iteration import IterationNodeData
 from core.workflow.nodes.node_mapping import node_type_classes_mapping
 from core.workflow.workflow_entry import WorkflowEntry
 from extensions.ext_database import db
@@ -160,8 +159,6 @@ class WorkflowBasedAppRunner(AppRunner):
             user_inputs=user_inputs,
             variable_pool=variable_pool,
             tenant_id=workflow.tenant_id,
-            node_type=node_type,
-            node_data=IterationNodeData(**iteration_node_config.get("data", {})),
         )
 
         return graph, variable_pool
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 69bc0d7f9e..15543638fc 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, field_validator
@@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType
 from core.workflow.nodes.base import BaseNodeData
 
 
-class QueueEvent(str, Enum):
+class QueueEvent(StrEnum):
     """
     QueueEvent enum
     """
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py
index 042339969f..d45726af46 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/app/task_pipeline/workflow_cycle_manage.py
@@ -1,8 +1,9 @@
 import json
 import time
 from collections.abc import Mapping, Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Any, Optional, Union, cast
+from uuid import uuid4
 
 from sqlalchemy.orm import Session
 
@@ -80,38 +81,38 @@ class WorkflowCycleManage:
 
                 inputs[f"sys.{key.value}"] = value
 
-        inputs = WorkflowEntry.handle_special_values(inputs)
-
         triggered_from = (
             WorkflowRunTriggeredFrom.DEBUGGING
             if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
             else WorkflowRunTriggeredFrom.APP_RUN
         )
 
-        # init workflow run
-        workflow_run = WorkflowRun()
-        workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
-        if workflow_run_id:
-            workflow_run.id = workflow_run_id
-        workflow_run.tenant_id = self._workflow.tenant_id
-        workflow_run.app_id = self._workflow.app_id
-        workflow_run.sequence_number = new_sequence_number
-        workflow_run.workflow_id = self._workflow.id
-        workflow_run.type = self._workflow.type
-        workflow_run.triggered_from = triggered_from.value
-        workflow_run.version = self._workflow.version
-        workflow_run.graph = self._workflow.graph
-        workflow_run.inputs = json.dumps(inputs)
-        workflow_run.status = WorkflowRunStatus.RUNNING.value
-        workflow_run.created_by_role = (
-            CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value
-        )
-        workflow_run.created_by = self._user.id
+        # handle special values
+        inputs = WorkflowEntry.handle_special_values(inputs)
 
-        db.session.add(workflow_run)
-        db.session.commit()
-        db.session.refresh(workflow_run)
-        db.session.close()
+        # init workflow run
+        with Session(db.engine, expire_on_commit=False) as session:
+            workflow_run = WorkflowRun()
+            system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID]
+            workflow_run.id = system_id or str(uuid4())
+            workflow_run.tenant_id = self._workflow.tenant_id
+            workflow_run.app_id = self._workflow.app_id
+            workflow_run.sequence_number = new_sequence_number
+            workflow_run.workflow_id = self._workflow.id
+            workflow_run.type = self._workflow.type
+            workflow_run.triggered_from = triggered_from.value
+            workflow_run.version = self._workflow.version
+            workflow_run.graph = self._workflow.graph
+            workflow_run.inputs = json.dumps(inputs)
+            workflow_run.status = WorkflowRunStatus.RUNNING
+            workflow_run.created_by_role = (
+                CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER
+            )
+            workflow_run.created_by = self._user.id
+            workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
+
+            session.add(workflow_run)
+            session.commit()
 
         return workflow_run
 
@@ -144,7 +145,7 @@ class WorkflowCycleManage:
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
-        workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         db.session.commit()
         db.session.refresh(workflow_run)
@@ -191,7 +192,7 @@ class WorkflowCycleManage:
         workflow_run.elapsed_time = time.perf_counter() - start_at
         workflow_run.total_tokens = total_tokens
         workflow_run.total_steps = total_steps
-        workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         db.session.commit()
 
@@ -211,15 +212,18 @@ class WorkflowCycleManage:
         for workflow_node_execution in running_workflow_node_executions:
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
-            workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
             workflow_node_execution.elapsed_time = (
                 workflow_node_execution.finished_at - workflow_node_execution.created_at
            ).total_seconds()
             db.session.commit()
 
-        db.session.refresh(workflow_run)
         db.session.close()
 
+        with Session(db.engine, expire_on_commit=False) as session:
+            session.add(workflow_run)
+            session.refresh(workflow_run)
+
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(
@@ -259,7 +263,7 @@ class WorkflowCycleManage:
                     NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
                 }
             )
-        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
 
         session.add(workflow_node_execution)
         session.commit()
@@ -282,7 +286,7 @@ class WorkflowCycleManage:
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
         )
-        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
 
         db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
@@ -326,7 +330,7 @@ class WorkflowCycleManage:
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
-        finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
         execution_metadata = (
             json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
@@ -381,7 +385,7 @@ class WorkflowCycleManage:
                 id=workflow_run.id,
                 workflow_id=workflow_run.workflow_id,
                 sequence_number=workflow_run.sequence_number,
-                inputs=workflow_run.inputs_dict or {},
+                inputs=workflow_run.inputs_dict,
                 created_at=int(workflow_run.created_at.timestamp()),
             ),
         )
@@ -428,7 +432,7 @@ class WorkflowCycleManage:
                 created_by=created_by,
                 created_at=int(workflow_run.created_at.timestamp()),
                 finished_at=int(workflow_run.finished_at.timestamp()),
-                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict or {}),
+                files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict),
             ),
         )
 
@@ -654,7 +658,7 @@ class WorkflowCycleManage:
             if event.error is None
             else WorkflowNodeExecutionStatus.FAILED,
             error=None,
-            elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(),
+            elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
             total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
             execution_metadata=event.metadata,
             finished_at=int(time.time()),
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 807f09598c..d1b34db2fe 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -240,7 +240,7 @@ class ProviderConfiguration(BaseModel):
         if provider_record:
             provider_record.encrypted_config = json.dumps(credentials)
             provider_record.is_valid = True
-            provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             provider_record = Provider(
@@ -394,7 +394,7 @@ class ProviderConfiguration(BaseModel):
         if provider_model_record:
             provider_model_record.encrypted_config = json.dumps(credentials)
             provider_model_record.is_valid = True
-            provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             provider_model_record = ProviderModel(
@@ -468,7 +468,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.enabled = True
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -503,7 +503,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.enabled = False
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -570,7 +570,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.load_balancing_enabled = True
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
@@ -605,7 +605,7 @@ class ProviderConfiguration(BaseModel):
 
         if model_setting:
             model_setting.load_balancing_enabled = False
-            model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
         else:
             model_setting = ProviderModelSetting(
diff --git a/api/core/file/enums.py b/api/core/file/enums.py
index f4153f1676..06b99d3eb0 100644
--- a/api/core/file/enums.py
+++ b/api/core/file/enums.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum
 
 
-class FileType(str, Enum):
+class FileType(StrEnum):
     IMAGE = "image"
     DOCUMENT = "document"
     AUDIO = "audio"
@@ -16,7 +16,7 @@ class FileType(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileTransferMethod(str, Enum):
+class FileTransferMethod(StrEnum):
     REMOTE_URL = "remote_url"
     LOCAL_FILE = "local_file"
     TOOL_FILE = "tool_file"
@@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileBelongsTo(str, Enum):
+class FileBelongsTo(StrEnum):
     USER = "user"
     ASSISTANT = "assistant"
 
@@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class FileAttribute(str, Enum):
+class FileAttribute(StrEnum):
     TYPE = "type"
     SIZE = "size"
     NAME = "name"
@@ -51,5 +51,5 @@ class FileAttribute(str, Enum):
     EXTENSION = "extension"
 
 
-class ArrayFileAttribute(str, Enum):
+class ArrayFileAttribute(StrEnum):
     LENGTH = "length"
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index eb260a8f84..6d8086435d 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -3,7 +3,12 @@ import base64
 
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
+from core.model_runtime.entities import (
+    AudioPromptMessageContent,
+    DocumentPromptMessageContent,
+    ImagePromptMessageContent,
+    VideoPromptMessageContent,
+)
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 
@@ -29,35 +34,17 @@ def get_attr(*, file: File, attr: FileAttribute):
             return file.remote_url
         case FileAttribute.EXTENSION:
             return file.extension
-        case _:
-            raise ValueError(f"Invalid file attribute: {attr}")
 
 
 def to_prompt_message_content(
     f: File,
     /,
     *,
-    image_detail_config: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.LOW,
+    image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
 ):
-    """
-    Convert a File object to an ImagePromptMessageContent or AudioPromptMessageContent object.
-
-    This function takes a File object and converts it to an appropriate PromptMessageContent
-    object, which can be used as a prompt for image or audio-based AI models.
-
-    Args:
-        f (File): The File object to convert.
-        detail (Optional[ImagePromptMessageContent.DETAIL]): The detail level for image prompts.
-            If not provided, defaults to ImagePromptMessageContent.DETAIL.LOW.
-
-    Returns:
-        Union[ImagePromptMessageContent, AudioPromptMessageContent]: An object containing the file data and detail level
-
-    Raises:
-        ValueError: If the file type is not supported or if required data is missing.
- """ match f.type: case FileType.IMAGE: + image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW if dify_config.MULTIMODAL_SEND_IMAGE_FORMAT == "url": data = _to_url(f) else: @@ -65,7 +52,7 @@ def to_prompt_message_content( return ImagePromptMessageContent(data=data, detail=image_detail_config) case FileType.AUDIO: - encoded_string = _file_to_encoded_string(f) + encoded_string = _get_encoded_string(f) if f.extension is None: raise ValueError("Missing file extension") return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip(".")) @@ -74,9 +61,20 @@ def to_prompt_message_content( data = _to_url(f) else: data = _to_base64_data_string(f) + if f.extension is None: + raise ValueError("Missing file extension") return VideoPromptMessageContent(data=data, format=f.extension.lstrip(".")) + case FileType.DOCUMENT: + data = _get_encoded_string(f) + if f.mime_type is None: + raise ValueError("Missing file mime_type") + return DocumentPromptMessageContent( + encode_format="base64", + mime_type=f.mime_type, + data=data, + ) case _: - raise ValueError("file type f.type is not supported") + raise ValueError(f"file type {f.type} is not supported") def download(f: File, /): @@ -118,21 +116,16 @@ def _get_encoded_string(f: File, /): case FileTransferMethod.REMOTE_URL: response = ssrf_proxy.get(f.remote_url, follow_redirects=True) response.raise_for_status() - content = response.content - encoded_string = base64.b64encode(content).decode("utf-8") - return encoded_string + data = response.content case FileTransferMethod.LOCAL_FILE: upload_file = file_repository.get_upload_file(session=db.session(), file=f) data = _download_file_content(upload_file.key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string case FileTransferMethod.TOOL_FILE: tool_file = file_repository.get_tool_file(session=db.session(), file=f) data = _download_file_content(tool_file.file_key) - encoded_string = base64.b64encode(data).decode("utf-8") - return encoded_string - case _: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") + + encoded_string = base64.b64encode(data).decode("utf-8") + return encoded_string def _to_base64_data_string(f: File, /): @@ -140,18 +133,6 @@ def _to_base64_data_string(f: File, /): return f"data:{f.mime_type};base64,{encoded_string}" -def _file_to_encoded_string(f: File, /): - match f.type: - case FileType.IMAGE: - return _to_base64_data_string(f) - case FileType.VIDEO: - return _to_base64_data_string(f) - case FileType.AUDIO: - return _get_encoded_string(f) - case _: - raise ValueError(f"file type {f.type} is not supported") - - def _to_url(f: File, /): if f.transfer_method == FileTransferMethod.REMOTE_URL: if f.remote_url is None: diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 03c4b8d49d..011ff382ea 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from enum import Enum +from enum import StrEnum from threading import Lock from typing import Any, Optional @@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel): data: Data -class CodeLanguage(str, Enum): +class CodeLanguage(StrEnum): PYTHON3 = "python3" JINJA2 = "jinja2" JAVASCRIPT = "javascript" diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 7db8f54f70..29e161cb74 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py 
@@ -30,6 +30,7 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.tools.utils.web_reader_tool import get_image_upload_file_ids from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -85,7 +86,7 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: logging.warning("Document deleted, document id: {}".format(dataset_document.id)) @@ -93,7 +94,7 @@ class IndexingRunner: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def run_in_splitting_status(self, dataset_document: DatasetDocument): @@ -141,13 +142,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): @@ -199,13 +200,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def indexing_estimate( @@ -279,6 +280,19 @@ class IndexingRunner: if len(preview_texts) < 5: preview_texts.append(document.page_content) + # delete image files and related db records + image_upload_file_ids = get_image_upload_file_ids(document.page_content) + for upload_file_id in image_upload_file_ids: + image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first() + try: + storage.delete(image_file.key) + except Exception: + logging.exception( + "Delete image_files failed while indexing_estimate, \ + image_upload_file_is: {}".format(upload_file_id) + ) + db.session.delete(image_file) + if doc_form and doc_form == "qa_model": if len(preview_texts) > 0: # qa model document 
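> Editorial note: the hunks above lean on two Python 3.11+ features, matching the CI matrix that now starts at 3.11: `enum.StrEnum` replaces the `(str, Enum)` mixin pattern, and `datetime.UTC` aliases `timezone.utc`. A minimal sketch, not part of the patch and with illustrative names, of why both swaps are safe for this usage:

```python
# Illustrative sketch only -- requires Python 3.11+; RunStatus is a stand-in
# for enums like QueueEvent or FileType in the hunks above.
from datetime import UTC, datetime, timezone
from enum import StrEnum


class RunStatus(StrEnum):
    RUNNING = "running"


# StrEnum members *are* strings, so comparisons and serialization keep working:
assert RunStatus.RUNNING == "running"
assert str(RunStatus.RUNNING) == "running"

# datetime.UTC (new in 3.11) is just an alias for timezone.utc, so the
# naive-UTC timestamp convention used throughout these hunks is unchanged:
assert UTC is timezone.utc
naive_utc_now = datetime.now(UTC).replace(tzinfo=None)
```

> The `StrEnum` half matters because Python 3.11 changed `format()`/f-string output of `(str, Enum)` mixins to `ClassName.MEMBER`, while `StrEnum` renders the raw value, which is what serialized statuses and event names here rely on.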
@@ -358,7 +372,7 @@ class IndexingRunner:
             after_indexing_status="splitting",
             extra_update_params={
                 DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
-                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             },
         )
 
@@ -450,7 +464,7 @@ class IndexingRunner:
         doc_store.add_documents(documents)
 
         # update document status to indexing
-        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         self._update_document_index_status(
             document_id=dataset_document.id,
             after_indexing_status="indexing",
@@ -465,7 +479,7 @@ class IndexingRunner:
             dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
-                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             },
         )
 
@@ -666,7 +680,7 @@ class IndexingRunner:
             after_indexing_status="completed",
             extra_update_params={
                 DatasetDocument.tokens: tokens,
-                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                 DatasetDocument.error: None,
             },
@@ -691,7 +705,7 @@ class IndexingRunner:
                 {
                     DocumentSegment.status: "completed",
                     DocumentSegment.enabled: True,
-                    DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 }
             )
 
@@ -724,7 +738,7 @@ class IndexingRunner:
                 {
                     DocumentSegment.status: "completed",
                     DocumentSegment.enabled: True,
-                    DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                    DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 }
             )
 
@@ -835,7 +849,7 @@ class IndexingRunner:
         doc_store.add_documents(documents)
 
         # update document status to indexing
-        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         self._update_document_index_status(
             document_id=dataset_document.id,
             after_indexing_status="indexing",
@@ -850,7 +864,7 @@ class IndexingRunner:
             dataset_document_id=dataset_document.id,
             update_params={
                 DocumentSegment.status: "indexing",
-                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             },
         )
         pass
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 688fb4776a..81d08dc885 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -1,8 +1,8 @@
+from collections.abc import Sequence
 from typing import Optional
 
 from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
 from core.file import file_manager
-from core.file.models import FileType
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import (
     AssistantPromptMessage,
@@ -27,7 +27,7 @@ class TokenBufferMemory:
 
     def get_history_prompt_messages(
         self, max_token_limit: int = 2000, message_limit: Optional[int] = None
-    ) -> list[PromptMessage]:
+    ) -> Sequence[PromptMessage]:
         """
         Get history prompt messages.
         :param max_token_limit: max token limit
@@ -102,12 +102,11 @@ class TokenBufferMemory:
                     prompt_message_contents: list[PromptMessageContent] = []
                     prompt_message_contents.append(TextPromptMessageContent(data=message.query))
                     for file in file_objs:
-                        if file.type in {FileType.IMAGE, FileType.AUDIO}:
-                            prompt_message = file_manager.to_prompt_message_content(
-                                file,
-                                image_detail_config=detail,
-                            )
-                            prompt_message_contents.append(prompt_message)
+                        prompt_message = file_manager.to_prompt_message_content(
+                            file,
+                            image_detail_config=detail,
+                        )
+                        prompt_message_contents.append(prompt_message)
 
                     prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 059ba6c3d1..1986688551 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -100,10 +100,10 @@ class ModelInstance:
 
     def invoke_llm(
         self,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Sequence[PromptMessageTool] | None = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py
index 6bd9325785..8870b34435 100644
--- a/api/core/model_runtime/callbacks/base_callback.py
+++ b/api/core/model_runtime/callbacks/base_callback.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from typing import Optional
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
@@ -31,7 +32,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -60,7 +61,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ):
@@ -90,7 +91,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
@@ -120,7 +121,7 @@ class Callback(ABC):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> None:
diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py
index f5d4427e3e..5e52f10b4c 100644
--- a/api/core/model_runtime/entities/__init__.py
+++ b/api/core/model_runtime/entities/__init__.py
@@ -2,6 +2,7 @@ from .llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsa
 from .message_entities import (
     AssistantPromptMessage,
     AudioPromptMessageContent,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
@@ -37,4 +38,5 @@ __all__ = [
     "LLMResultChunk",
     "LLMResultChunkDelta",
     "AudioPromptMessageContent",
+    "DocumentPromptMessageContent",
 ]
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 3c244d368e..f2870209bb 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,6 +1,7 @@
 from abc import ABC
-from enum import Enum
-from typing import Optional
+from collections.abc import Sequence
+from enum import Enum, StrEnum
+from typing import Literal, Optional
 
 from pydantic import BaseModel, Field, field_validator
 
@@ -48,7 +49,7 @@ class PromptMessageFunction(BaseModel):
     function: PromptMessageTool
 
 
-class PromptMessageContentType(Enum):
+class PromptMessageContentType(StrEnum):
     """
     Enum class for prompt message content type.
     """
@@ -57,6 +58,7 @@ class PromptMessageContentType(Enum):
     IMAGE = "image"
     AUDIO = "audio"
     VIDEO = "video"
+    DOCUMENT = "document"
 
 
 class PromptMessageContent(BaseModel):
@@ -93,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent):
     Model class for image prompt message content.
     """
 
-    class DETAIL(str, Enum):
+    class DETAIL(StrEnum):
         LOW = "low"
         HIGH = "high"
 
@@ -101,13 +103,20 @@ class ImagePromptMessageContent(PromptMessageContent):
     detail: DETAIL = DETAIL.LOW
 
 
+class DocumentPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+    encode_format: Literal["base64"]
+    mime_type: str
+    data: str
+
+
 class PromptMessage(ABC, BaseModel):
     """
     Model class for prompt message.
     """
 
     role: PromptMessageRole
-    content: Optional[str | list[PromptMessageContent]] = None
+    content: Optional[str | Sequence[PromptMessageContent]] = None
     name: Optional[str] = None
 
     def is_empty(self) -> bool:
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index 52ea787c3a..edc6eac517 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -1,5 +1,5 @@
 from decimal import Decimal
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional
 
 from pydantic import BaseModel, ConfigDict
@@ -87,9 +87,12 @@ class ModelFeature(Enum):
     AGENT_THOUGHT = "agent-thought"
     VISION = "vision"
     STREAM_TOOL_CALL = "stream-tool-call"
+    DOCUMENT = "document"
+    VIDEO = "video"
+    AUDIO = "audio"
 
 
-class DefaultParameterName(str, Enum):
+class DefaultParameterName(StrEnum):
     """
     Enum class for parameter template variable.
     """
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 5b6f96129b..8faeffa872 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -2,7 +2,7 @@ import logging
 import re
 import time
 from abc import abstractmethod
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Mapping, Sequence
 from typing import Optional, Union
 
 from pydantic import ConfigDict
 
@@ -48,7 +48,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: Optional[dict] = None,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -169,7 +169,7 @@ class LargeLanguageModel(AIModel):
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -212,7 +212,7 @@ if you are not sure about the structure.
             )
 
             model_parameters.pop("response_format")
-            stop = stop or []
+            stop = list(stop) if stop is not None else []
             stop.extend(["\n```", "```\n"])
             block_prompts = block_prompts.replace("{{block}}", code_block)
@@ -408,7 +408,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -479,7 +479,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
@@ -601,7 +601,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -647,7 +647,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -694,7 +694,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
@@ -742,7 +742,7 @@ if you are not sure about the structure.
         prompt_messages: list[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
         callbacks: Optional[list[Callback]] = None,
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml
index e02c5517fe..4eb56bbc0e 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20240620.yaml
@@ -7,6 +7,7 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
 model_properties:
   mode: chat
   context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml
index e20b8c4960..81822b162e 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml
+++ b/api/core/model_runtime/model_providers/anthropic/llm/claude-3-5-sonnet-20241022.yaml
@@ -7,6 +7,7 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
 model_properties:
   mode: chat
   context_size: 200000
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index 4e7faab891..b24324708b 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -1,7 +1,7 @@
 import base64
 import io
 import json
-from collections.abc import Generator
+from collections.abc import Generator, Sequence
 from typing import Optional, Union, cast
 
 import anthropic
@@ -21,9 +21,9 @@ from httpx import Timeout
 from PIL import Image
 
 from core.model_runtime.callbacks.base_callback import Callback
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (
+from core.model_runtime.entities import (
     AssistantPromptMessage,
+    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
@@ -33,6 +33,7 @@ from core.model_runtime.entities.message_entities import (
     ToolPromptMessage,
     UserPromptMessage,
 )
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
     InvokeBadRequestError,
@@ -86,10 +87,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         self,
         model: str,
         credentials: dict,
-        prompt_messages: list[PromptMessage],
+        prompt_messages: Sequence[PromptMessage],
         model_parameters: dict,
         tools: Optional[list[PromptMessageTool]] = None,
-        stop: Optional[list[str]] = None,
+        stop: Optional[Sequence[str]] = None,
         stream: bool = True,
         user: Optional[str] = None,
     ) -> Union[LLMResult, Generator]:
@@ -130,9 +131,17 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         # Add the new header for claude-3-5-sonnet-20240620 model
         extra_headers = {}
         if model == "claude-3-5-sonnet-20240620":
-            if model_parameters.get("max_tokens") > 4096:
+            if model_parameters.get("max_tokens", 0) > 4096:
                 extra_headers["anthropic-beta"] = "max-tokens-3-5-sonnet-2024-07-15"
 
+        if any(
+            isinstance(content, DocumentPromptMessageContent)
+            for prompt_message in prompt_messages
+            if isinstance(prompt_message.content, list)
+            for content in prompt_message.content
+        ):
+            extra_headers["anthropic-beta"] = "pdfs-2024-09-25"
+
         if tools:
             extra_model_kwargs["tools"] = [self._transform_tool_prompt(tool) for tool in tools]
             response = client.beta.tools.messages.create(
@@ -444,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
 
         return credentials_kwargs
 
-    def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]:
+    def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]:
        """
         Convert prompt messages to dict list and system
         """
@@ -452,7 +461,15 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
-                message.content = message.content.strip()
+                if isinstance(message.content, str):
+                    message.content = message.content.strip()
+                elif isinstance(message.content, list):
+                    # System prompt only support text
+                    message.content = "".join(
+                        c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent)
+                    )
+                else:
+                    raise ValueError(f"Unknown system prompt message content type {type(message.content)}")
                 if first_loop:
                     system = message.content
                     first_loop = False
@@ -504,6 +521,21 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
                             "source": {"type": "base64", "media_type": mime_type, "data": base64_data},
                         }
                         sub_messages.append(sub_message_dict)
+                    elif isinstance(message_content, DocumentPromptMessageContent):
+                        if message_content.mime_type != "application/pdf":
+                            raise ValueError(
+                                f"Unsupported document type {message_content.mime_type}, "
+                                "only support application/pdf"
+                            )
+                        sub_message_dict = {
+                            "type": "document",
+                            "source": {
+                                "type": message_content.encode_format,
+                                "media_type": message_content.mime_type,
+                                "data": message_content.data,
+                            },
+                        }
+                        sub_messages.append(sub_message_dict)
                 prompt_message_dicts.append({"role": "user", "content": sub_messages})
             elif isinstance(message, AssistantPromptMessage):
                 message = cast(AssistantPromptMessage, message)
diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py
index e61a9e0474..4cf58275d7 100644
--- a/api/core/model_runtime/model_providers/azure_openai/_constant.py
+++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py
@@ -779,7 +779,7 @@ LLM_BASE_MODELS = [
                     name="frequency_penalty",
                     **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
                 ),
-                _get_max_tokens(default=512, min_val=1, max_val=4096),
+                _get_max_tokens(default=512, min_val=1, max_val=16384),
                 ParameterRule(
                     name="seed",
                     label=I18nObject(zh_Hans="种子", en_US="Seed"),
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index ff0403ee47..ef4dfaf6f1 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -2,13 +2,11 @@
 import base64
 import json
 import logging
-import mimetypes
 from collections.abc import Generator
 from typing import Optional, Union, cast
 
 # 3rd import
 import boto3
-import requests
 from botocore.config import Config
 from botocore.exceptions import (
     ClientError,
@@ -439,22 +437,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
                         sub_messages.append(sub_message_dict)
                     elif message_content.type == PromptMessageContentType.IMAGE:
                         message_content = cast(ImagePromptMessageContent, message_content)
-                        if not message_content.data.startswith("data:"):
-                            # fetch image data from url
-                            try:
-                                url = message_content.data
-                                image_content = requests.get(url).content
-                                if "?" in url:
-                                    url = url.split("?")[0]
-                                mime_type, _ = mimetypes.guess_type(url)
-                                base64_data = base64.b64encode(image_content).decode("utf-8")
-                            except Exception as ex:
-                                raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                        else:
-                            data_split = message_content.data.split(";base64,")
-                            mime_type = data_split[0].replace("data:", "")
-                            base64_data = data_split[1]
-                            image_content = base64.b64decode(base64_data)
+                        data_split = message_content.data.split(";base64,")
+                        mime_type = data_split[0].replace("data:", "")
+                        base64_data = data_split[1]
+                        image_content = base64.b64decode(base64_data)
 
                         if mime_type not in {"image/jpeg", "image/png", "image/gif", "image/webp"}:
                             raise ValueError(
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml
index 9781965555..e5e0244a87 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-5-haiku-v1.yaml
@@ -15,9 +15,9 @@ parameter_rules:
     use_template: max_tokens
     required: true
     type: int
-    default: 4096
+    default: 8192
     min: 1
-    max: 4096
+    max: 8192
    help:
       zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
       en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml
index 31a403289b..61f73276ee 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml
+++ b/api/core/model_runtime/model_providers/bedrock/llm/us.anthropic.claude-3-sonnet-v2.yaml
@@ -16,9 +16,9 @@ parameter_rules:
     use_template: max_tokens
     required: true
     type: int
-    default: 4096
+    default: 8192
     min: 1
-    max: 4096
+    max: 8192
     help:
       zh_Hans: 停止前生成的最大令牌数。请注意,Anthropic Claude 模型可能会在达到 max_tokens 的值之前停止生成令牌。不同的 Anthropic Claude 模型对此参数具有不同的最大值。
       en_US: The maximum number of tokens to generate before stopping. Note that Anthropic Claude models might stop generating tokens before reaching the value of max_tokens. Different Anthropic Claude models have different maximum values for this parameter.
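> Editorial note: the Anthropic changes above are the consumer side of the `DocumentPromptMessageContent` type introduced in `message_entities.py` earlier in this diff. A hedged sketch of the round trip, assuming the exports shown above; values are placeholders, not from the patch:

```python
# Illustrative only: the shape _convert_prompt_messages produces for a PDF,
# per the anthropic/llm/llm.py hunk above (not an exact excerpt).
from core.model_runtime.entities import DocumentPromptMessageContent, UserPromptMessage

pdf_content = DocumentPromptMessageContent(
    encode_format="base64",        # the only encoding the converter accepts
    mime_type="application/pdf",   # anything else raises ValueError above
    data="JVBERi0xLjQK...",        # base64-encoded PDF bytes (truncated placeholder)
)
message = UserPromptMessage(content=[pdf_content])

# The provider turns this into an Anthropic content block like:
#   {"type": "document",
#    "source": {"type": "base64", "media_type": "application/pdf", "data": "..."}}
# and, because a document is present, also sends the beta header
#   anthropic-beta: pdfs-2024-09-25
```

> Note the design choice visible in the hunk: the beta-header check scans every message's content list once, so PDF support costs nothing on requests that carry no documents.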
diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py index 3863ad3308..f230157a34 100644 --- a/api/core/model_runtime/model_providers/cohere/llm/llm.py +++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py @@ -691,8 +691,8 @@ class CohereLargeLanguageModel(LargeLanguageModel): base_model_schema = cast(AIModelEntity, base_model_schema) base_model_schema_features = base_model_schema.features or [] - base_model_schema_model_properties = base_model_schema.model_properties or {} - base_model_schema_parameters_rules = base_model_schema.parameter_rules or [] + base_model_schema_model_properties = base_model_schema.model_properties + base_model_schema_parameters_rules = base_model_schema.parameter_rules entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml index 4973ac8ad6..0bbd27ad74 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - tool-call - multi-tool-call - stream-tool-call model_properties: @@ -72,7 +73,7 @@ parameter_rules: - text - json_object pricing: - input: '1' - output: '2' - unit: '0.000001' + input: "1" + output: "2" + unit: "0.000001" currency: RMB diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml index caafeadadd..97310e76b9 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - tool-call - multi-tool-call - stream-tool-call model_properties: diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index 6d0a3ee262..610dc7b458 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -1,18 +1,17 @@ from collections.abc import Generator from typing import Optional, Union -from urllib.parse import urlparse -import tiktoken +from yarl import URL -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, ) -from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel -class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): +class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel): def _invoke( self, model: str, @@ -25,92 +24,15 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): user: Optional[str] = None, ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) 
super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: - """ - Calculate num tokens for text completion model with tiktoken package. - - :param model: model name - :param text: prompt text - :param tools: tools for tool calling - :return: number of tokens - """ - encoding = tiktoken.get_encoding("cl100k_base") - num_tokens = len(encoding.encode(text)) - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages( - self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None - ) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - encoding = tiktoken.get_encoding("cl100k_base") - tokens_per_message = 3 - tokens_per_name = 1 - - num_tokens = 0 - messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - # Cast str(value) in case the message value is not a string - # This occurs with function messages - # TODO: The current token calculation method for the image type is not implemented, - # which need to download the image and then get the resolution for calculation, - # and will increase the request delay - if isinstance(value, list): - text = "" - for item in value: - if isinstance(item, dict) and item["type"] == "text": - text += item["text"] - - value = text - - if key == "tool_calls": - for tool_call in value: - for t_key, t_value in tool_call.items(): - num_tokens += len(encoding.encode(t_key)) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += len(encoding.encode(f_key)) - num_tokens += len(encoding.encode(f_value)) - else: - num_tokens += len(encoding.encode(t_key)) - num_tokens += len(encoding.encode(t_value)) - else: - num_tokens += len(encoding.encode(str(value))) - - if key == "name": - num_tokens += tokens_per_name - - # every reply is primed with assistant - num_tokens += 3 - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: - credentials["mode"] = "chat" - credentials["openai_api_key"] = credentials["api_key"] - if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - credentials["openai_api_base"] = "https://api.deepseek.com" - else: - parsed_url = urlparse(credentials["endpoint_url"]) - credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com"))) + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" + credentials["stream_function_calling"] = "support" diff --git a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py index 3bc4b533e0..a99803eeea 100644 --- a/api/core/model_runtime/model_providers/fishaudio/fishaudio.py +++ 
b/api/core/model_runtime/model_providers/fishaudio/fishaudio.py @@ -18,7 +18,8 @@ class FishAudioProvider(ModelProvider): """ try: model_instance = self.get_model_instance(ModelType.TTS) - model_instance.validate_credentials(credentials=credentials) + # FIXME fish tts do not have model for now, so set it to empty string instead + model_instance.validate_credentials(model="", credentials=credentials) except CredentialsValidateFailedError as ex: raise ex except Exception as ex: diff --git a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py index e518d7b95b..43a34cb090 100644 --- a/api/core/model_runtime/model_providers/fishaudio/tts/tts.py +++ b/api/core/model_runtime/model_providers/fishaudio/tts/tts.py @@ -66,7 +66,7 @@ class FishAudioText2SpeechModel(TTSModel): voice=voice, ) - def validate_credentials(self, credentials: dict, user: Optional[str] = None) -> None: + def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: """ Validate credentials for text2speech model @@ -76,7 +76,7 @@ class FishAudioText2SpeechModel(TTSModel): try: self.get_tts_model_voices( - None, + "", credentials={ "api_key": credentials["api_key"], "api_base": credentials["api_base"], diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py index 231345c2f4..832ba92740 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -122,7 +122,7 @@ class GiteeAIRerankModel(RerankModel): label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))}, ) return entity diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml index 2e68fa8e6f..43f4e4787d 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-001.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml index 9f44504e89..7b9add6af1 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-002.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml index a3da9095e1..d6de82012e 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0827.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git 
a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml index 19373e4993..23b8d318fc 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-8b-exp-0924.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml index ca1f0b39b2..9762706cd7 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-exp-0827.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml index 24e8c3a74f..b9739d068e 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash-latest.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml index fa3e814fc3..d8ab4efc91 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-flash.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 1048576 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml index da125e6fab..05184823e4 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-001.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml index f683e54d3b..548fe6ddb2 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-002.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml index c67c156bdb..defab26acf 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0801.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 2097152 diff --git 
a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml index 56059fd799..9cbc889f17 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-exp-0827.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml index ec376f3186..e5aefcdb99 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro-latest.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml index 8394cdfb56..00bd3e8d99 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-1.5-pro.yaml @@ -7,6 +7,7 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat context_size: 2097152 diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml index f126627689..0515e706c2 100644 --- a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml +++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1114.yaml @@ -7,9 +7,10 @@ features: - vision - tool-call - stream-tool-call + - document model_properties: mode: chat - context_size: 2097152 + context_size: 32767 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml new file mode 100644 index 0000000000..9ca4f6e675 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/gemini-exp-1121.yaml @@ -0,0 +1,38 @@ +model: gemini-exp-1121 +label: + en_US: Gemini exp 1121 +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32767 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. 
+ required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml b/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml new file mode 100644 index 0000000000..0b29814289 --- /dev/null +++ b/api/core/model_runtime/model_providers/google/llm/learnlm-1.5-pro-experimental.yaml @@ -0,0 +1,38 @@ +model: learnlm-1.5-pro-experimental +label: + en_US: LearnLM 1.5 Pro Experimental +model_type: llm +features: + - agent-thought + - vision + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 32767 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: top_k + label: + zh_Hans: 取样数量 + en_US: Top k + type: int + help: + zh_Hans: 仅从每个后续标记的前 K 个选项中采样。 + en_US: Only sample from the top K options for each subsequent token. + required: false + - name: max_output_tokens + use_template: max_tokens + default: 8192 + min: 1 + max: 8192 + - name: json_schema + use_template: json_schema +pricing: + input: '0.00' + output: '0.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 754f056ac1..77e0801b63 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -16,6 +16,7 @@ from PIL import Image from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + DocumentPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, @@ -35,6 +36,21 @@ from core.model_runtime.errors.invoke import ( from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +GOOGLE_AVAILABLE_MIMETYPE = [ + "application/pdf", + "application/x-javascript", + "text/javascript", + "application/x-python", + "text/x-python", + "text/plain", + "text/html", + "text/css", + "text/md", + "text/csv", + "text/xml", + "text/rtf", +] + class GoogleLargeLanguageModel(LargeLanguageModel): def _invoke( @@ -370,6 +386,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel): raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}") blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}} glm_content["parts"].append(blob) + elif c.type == PromptMessageContentType.DOCUMENT: + message_content = cast(DocumentPromptMessageContent, c) + if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE: + raise ValueError(f"Unsupported mime type {message_content.mime_type}") + blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}} + glm_content["parts"].append(blob) return glm_content elif isinstance(message, AssistantPromptMessage): diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py index 5ea7532564..feb5777028 100644 --- a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py @@ 
-140,7 +140,7 @@ class GPUStackRerankModel(RerankModel): label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))}, ) return entity diff --git a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml index f3a912d84d..e81da51048 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml +++ b/api/core/model_runtime/model_providers/huggingface_tei/huggingface_tei.yaml @@ -34,3 +34,11 @@ model_credential_schema: placeholder: zh_Hans: 在此输入Text Embedding Inference的服务器地址,如 http://192.168.1.100:8080 en_US: Enter the url of your Text Embedding Inference, e.g. http://192.168.1.100:8080 + - variable: api_key + label: + en_US: API Key + type: secret-input + required: false + placeholder: + zh_Hans: 在此输入您的 API Key + en_US: Enter your API Key diff --git a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py index 0bb9a9c8b5..06f76c2d85 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/rerank/rerank.py @@ -51,8 +51,13 @@ class HuggingfaceTeiRerankModel(RerankModel): server_url = server_url.removesuffix("/") + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + try: - results = TeiHelper.invoke_rerank(server_url, query, docs) + results = TeiHelper.invoke_rerank(server_url, query, docs, headers) rerank_documents = [] for result in results: @@ -80,7 +85,11 @@ class HuggingfaceTeiRerankModel(RerankModel): """ try: server_url = credentials["server_url"] - extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) + headers = {"Content-Type": "application/json"} + api_key = credentials.get("api_key") + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers) if extra_args.model_type != "reranker": raise CredentialsValidateFailedError("Current model is not a rerank model") diff --git a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py index 81ab249214..3ffcf4175e 100644 --- a/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py +++ b/api/core/model_runtime/model_providers/huggingface_tei/tei_helper.py @@ -26,13 +26,15 @@ cache_lock = Lock() class TeiHelper: @staticmethod - def get_tei_extra_parameter(server_url: str, model_name: str) -> TeiModelExtraParameter: + def get_tei_extra_parameter( + server_url: str, model_name: str, headers: Optional[dict] = None + ) -> TeiModelExtraParameter: TeiHelper._clean_cache() with cache_lock: if model_name not in cache: cache[model_name] = { "expires": time() + 300, - "value": TeiHelper._get_tei_extra_parameter(server_url), + "value": TeiHelper._get_tei_extra_parameter(server_url, headers), } return cache[model_name]["value"] @@ -47,7 +49,7 @@ class TeiHelper: pass @staticmethod - def _get_tei_extra_parameter(server_url: str) -> TeiModelExtraParameter: + def _get_tei_extra_parameter(server_url: str, headers: Optional[dict] = None) -> 
TeiModelExtraParameter:
         """
         get tei model extra parameter like model_type, max_input_length, max_batch_requests
         """
@@ -61,7 +63,7 @@ class TeiHelper:
         session.mount("https://", HTTPAdapter(max_retries=3))
 
         try:
-            response = session.get(url, timeout=10)
+            response = session.get(url, headers=headers, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f"get tei model extra parameter failed, url: {url}, error: {e}")
         if response.status_code != 200:
@@ -86,7 +88,7 @@ class TeiHelper:
         )
 
     @staticmethod
-    def invoke_tokenize(server_url: str, texts: list[str]) -> list[list[dict]]:
+    def invoke_tokenize(server_url: str, texts: list[str], headers: Optional[dict] = None) -> list[list[dict]]:
         """
         Invoke tokenize endpoint
 
@@ -114,15 +116,15 @@ class TeiHelper:
         :param server_url: server url
         :param texts: texts to tokenize
         """
-        resp = httpx.post(
-            f"{server_url}/tokenize",
-            json={"inputs": texts},
-        )
+        url = f"{server_url}/tokenize"
+        json_data = {"inputs": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
+        resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_embeddings(server_url: str, texts: list[str]) -> dict:
+    def invoke_embeddings(server_url: str, texts: list[str], headers: Optional[dict] = None) -> dict:
         """
         Invoke embeddings endpoint
 
@@ -147,15 +149,14 @@ class TeiHelper:
         :param texts: texts to embed
         """
         # Use OpenAI compatible API here, which has usage tracking
-        resp = httpx.post(
-            f"{server_url}/v1/embeddings",
-            json={"input": texts},
-        )
+        url = f"{server_url}/v1/embeddings"
+        json_data = {"input": texts}
+        resp = httpx.post(url, json=json_data, headers=headers)
         resp.raise_for_status()
         return resp.json()
 
     @staticmethod
-    def invoke_rerank(server_url: str, query: str, docs: list[str]) -> list[dict]:
+    def invoke_rerank(server_url: str, query: str, docs: list[str], headers: Optional[dict] = None) -> list[dict]:
         """
         Invoke rerank endpoint
 
@@ -173,10 +174,7 @@ class TeiHelper:
         :param candidates: candidates to rerank
         """
         params = {"query": query, "texts": docs, "return_text": True}
-
-        response = httpx.post(
-            server_url + "/rerank",
-            json=params,
-        )
+        url = f"{server_url}/rerank"
+        response = httpx.post(url, json=params, headers=headers)
         response.raise_for_status()
         return response.json()
diff --git a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
index a0917630a9..284429b741 100644
--- a/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/huggingface_tei/text_embedding/text_embedding.py
@@ -51,6 +51,10 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
 
         server_url = server_url.removesuffix("/")
 
+        headers = {"Content-Type": "application/json"}
+        api_key = credentials.get("api_key")
+        if api_key:
+            headers["Authorization"] = f"Bearer {api_key}"
         # get model properties
         context_size = self._get_context_size(model, credentials)
         max_chunks = self._get_max_chunks(model, credentials)
@@ -60,7 +64,7 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel):
         used_tokens = 0
 
         # get tokenized results from TEI
-        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts)
+        batched_tokenize_result = TeiHelper.invoke_tokenize(server_url, texts, headers)
 
         for i, (text, tokenize_result) in enumerate(zip(texts, batched_tokenize_result)):
             # Check if the number of tokens is larger than the context size
@@ -97,7 +101,7 
@@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): used_tokens = 0 for i in _iter: iter_texts = inputs[i : i + max_chunks] - results = TeiHelper.invoke_embeddings(server_url, iter_texts) + results = TeiHelper.invoke_embeddings(server_url, iter_texts, headers) embeddings = results["data"] embeddings = [embedding["embedding"] for embedding in embeddings] batched_embeddings.extend(embeddings) @@ -127,7 +131,11 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): server_url = server_url.removesuffix("/") - batch_tokens = TeiHelper.invoke_tokenize(server_url, texts) + headers = { + "Authorization": f"Bearer {credentials.get('api_key')}", + } + + batch_tokens = TeiHelper.invoke_tokenize(server_url, texts, headers) num_tokens = sum(len(tokens) for tokens in batch_tokens) return num_tokens @@ -141,7 +149,14 @@ class HuggingfaceTeiTextEmbeddingModel(TextEmbeddingModel): """ try: server_url = credentials["server_url"] - extra_args = TeiHelper.get_tei_extra_parameter(server_url, model) + headers = {"Content-Type": "application/json"} + + api_key = credentials.get("api_key") + + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + extra_args = TeiHelper.get_tei_extra_parameter(server_url, model, headers) print(extra_args) if extra_args.model_type != "embedding": raise CredentialsValidateFailedError("Current model is not a embedding model") diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index aacc8e75d3..22f882be6b 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -128,7 +128,7 @@ class JinaRerankModel(RerankModel): label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))}, ) return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 49c558f4a4..f5be7a9828 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -193,7 +193,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))}, ) return entity diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py index a7ea53e0e9..094a674645 100644 --- a/api/core/model_runtime/model_providers/ollama/llm/llm.py +++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py @@ -22,6 +22,7 @@ from core.model_runtime.entities.message_entities import ( PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage, ) from core.model_runtime.entities.model_entities import ( @@ -86,6 +87,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): credentials=credentials, prompt_messages=prompt_messages, model_parameters=model_parameters, + tools=tools, stop=stop, stream=stream, user=user, @@ 
-153,6 +155,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, @@ -196,6 +199,8 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if completion_type is LLMMode.CHAT: endpoint_url = urljoin(endpoint_url, "api/chat") data["messages"] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages] + if tools: + data["tools"] = [self._convert_prompt_message_tool_to_dict(tool) for tool in tools] else: endpoint_url = urljoin(endpoint_url, "api/generate") first_prompt_message = prompt_messages[0] @@ -232,7 +237,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): if stream: return self._handle_generate_stream_response(model, credentials, completion_type, response, prompt_messages) - return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages) + return self._handle_generate_response(model, credentials, completion_type, response, prompt_messages, tools) def _handle_generate_response( self, @@ -241,6 +246,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel): completion_type: LLMMode, response: requests.Response, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]], ) -> LLMResult: """ Handle llm completion response @@ -253,14 +259,16 @@ class OllamaLargeLanguageModel(LargeLanguageModel): :return: llm result """ response_json = response.json() - + tool_calls = [] if completion_type is LLMMode.CHAT: message = response_json.get("message", {}) response_content = message.get("content", "") + response_tool_calls = message.get("tool_calls", []) + tool_calls = [self._extract_response_tool_call(tool_call) for tool_call in response_tool_calls] else: response_content = response_json["response"] - assistant_message = AssistantPromptMessage(content=response_content) + assistant_message = AssistantPromptMessage(content=response_content, tool_calls=tool_calls) if "prompt_eval_count" in response_json and "eval_count" in response_json: # transform usage @@ -405,9 +413,28 @@ class OllamaLargeLanguageModel(LargeLanguageModel): chunk_index += 1 + def _convert_prompt_message_tool_to_dict(self, tool: PromptMessageTool) -> dict: + """ + Convert PromptMessageTool to dict for Ollama API + + :param tool: tool + :return: tool dict + """ + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters, + }, + } + def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict: """ Convert PromptMessage to dict for Ollama API + + :param message: prompt message + :return: message dict """ if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) @@ -432,6 +459,9 @@ class OllamaLargeLanguageModel(LargeLanguageModel): elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) message_dict = {"role": "system", "content": message.content} + elif isinstance(message, ToolPromptMessage): + message = cast(ToolPromptMessage, message) + message_dict = {"role": "tool", "content": message.content} else: raise ValueError(f"Got unknown type {message}") @@ -452,6 +482,29 @@ class OllamaLargeLanguageModel(LargeLanguageModel): return num_tokens + def _extract_response_tool_call(self, response_tool_call: dict) -> AssistantPromptMessage.ToolCall: + """ + Extract response tool call + """ + tool_call = None 
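+        # Ollama nests each call under a "function" key and may return the
+        # "arguments" field as a dict rather than a JSON string, hence the
+        # json.dumps normalization below. The response carries no separate
+        # call id, so the function name doubles as the ToolCall id.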
+ if response_tool_call and "function" in response_tool_call: + # Convert arguments to JSON string if it's a dict + arguments = response_tool_call.get("function").get("arguments") + if isinstance(arguments, dict): + arguments = json.dumps(arguments) + + function = AssistantPromptMessage.ToolCall.ToolCallFunction( + name=response_tool_call.get("function").get("name"), + arguments=arguments, + ) + tool_call = AssistantPromptMessage.ToolCall( + id=response_tool_call.get("function").get("name"), + type="function", + function=function, + ) + + return tool_call + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity: """ Get customizable model schema. @@ -461,10 +514,15 @@ class OllamaLargeLanguageModel(LargeLanguageModel): :return: model schema """ - extras = {} + extras = { + "features": [], + } if "vision_support" in credentials and credentials["vision_support"] == "true": - extras["features"] = [ModelFeature.VISION] + extras["features"].append(ModelFeature.VISION) + if "function_call_support" in credentials and credentials["function_call_support"] == "true": + extras["features"].append(ModelFeature.TOOL_CALL) + extras["features"].append(ModelFeature.MULTI_TOOL_CALL) entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/ollama/ollama.yaml b/api/core/model_runtime/model_providers/ollama/ollama.yaml index 33747753bd..6560fcd180 100644 --- a/api/core/model_runtime/model_providers/ollama/ollama.yaml +++ b/api/core/model_runtime/model_providers/ollama/ollama.yaml @@ -96,3 +96,22 @@ model_credential_schema: label: en_US: 'No' zh_Hans: 否 + - variable: function_call_support + label: + zh_Hans: 是否支持函数调用 + en_US: Function call support + show_on: + - variable: __model_type + value: llm + default: 'false' + type: radio + required: false + options: + - value: 'true' + label: + en_US: 'Yes' + zh_Hans: 是 + - value: 'false' + label: + en_US: 'No' + zh_Hans: 否 diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index a16c91cd7e..83c4facc8d 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -139,7 +139,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/openai/llm/_position.yaml b/api/core/model_runtime/model_providers/openai/llm/_position.yaml index b7c25ecb16..099aae38a6 100644 --- a/api/core/model_runtime/model_providers/openai/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/_position.yaml @@ -3,6 +3,7 @@ - gpt-4o - gpt-4o-2024-05-13 - gpt-4o-2024-08-06 +- gpt-4o-2024-11-20 - chatgpt-4o-latest - gpt-4o-mini - gpt-4o-mini-2024-07-18 diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml new file mode 100644 index 0000000000..ebd5ab38c3 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-2024-11-20.yaml @@ -0,0 +1,47 @@ +model: gpt-4o-2024-11-20 +label: + zh_Hans: gpt-4o-2024-11-20 + en_US: 
gpt-4o-2024-11-20 +model_type: llm +features: + - multi-tool-call + - agent-thought + - stream-tool-call + - vision +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: temperature + use_template: temperature + - name: top_p + use_template: top_p + - name: presence_penalty + use_template: presence_penalty + - name: frequency_penalty + use_template: frequency_penalty + - name: max_tokens + use_template: max_tokens + default: 512 + min: 1 + max: 16384 + - name: response_format + label: + zh_Hans: 回复格式 + en_US: Response Format + type: string + help: + zh_Hans: 指定模型必须输出的格式 + en_US: specifying the format that the model must output + required: false + options: + - text + - json_object + - json_schema + - name: json_schema + use_template: json_schema +pricing: + input: '2.50' + output: '10.00' + unit: '0.000001' + currency: USD diff --git a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml index 256e87edbe..6571cd094f 100644 --- a/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml +++ b/api/core/model_runtime/model_providers/openai/llm/gpt-4o-audio-preview.yaml @@ -7,7 +7,7 @@ features: - multi-tool-call - agent-thought - stream-tool-call - - vision + - audio model_properties: mode: chat context_size: 128000 diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index 68317d7179..07cb1e2d10 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -615,19 +615,11 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) # o1 compatibility - block_as_stream = False if model.startswith("o1"): if "max_tokens" in model_parameters: model_parameters["max_completion_tokens"] = model_parameters["max_tokens"] del model_parameters["max_tokens"] - if stream: - block_as_stream = True - stream = False - - if "stream_options" in extra_model_kwargs: - del extra_model_kwargs["stream_options"] - if "stop" in extra_model_kwargs: del extra_model_kwargs["stop"] @@ -644,47 +636,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): if stream: return self._handle_chat_generate_stream_response(model, credentials, response, prompt_messages, tools) - block_result = self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) - - if block_as_stream: - return self._handle_chat_block_as_stream_response(block_result, prompt_messages, stop) - - return block_result - - def _handle_chat_block_as_stream_response( - self, - block_result: LLMResult, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Handle llm chat response - - :param model: model name - :param credentials: credentials - :param response: response - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :param stop: stop words - :return: llm response chunk generator - """ - text = block_result.message.content - text = cast(str, text) - - if stop: - text = self.enforce_stop_tokens(text, stop) - - yield LLMResultChunk( - model=block_result.model, - prompt_messages=prompt_messages, - system_fingerprint=block_result.system_fingerprint, - delta=LLMResultChunkDelta( - index=0, - message=AssistantPromptMessage(content=text), - 
finish_reason="stop", - usage=block_result.usage, - ), - ) + return self._handle_chat_generate_response(model, credentials, response, prompt_messages, tools) def _handle_chat_generate_response( self, @@ -991,6 +943,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) + if isinstance(message.content, list): + text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content) + message.content = "".join(c.data for c in text_contents) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) @@ -1178,8 +1133,8 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): base_model_schema = model_map[base_model] base_model_schema_features = base_model_schema.features or [] - base_model_schema_model_properties = base_model_schema.model_properties or {} - base_model_schema_parameters_rules = base_model_schema.parameter_rules or [] + base_model_schema_model_properties = base_model_schema.model_properties + base_model_schema_parameters_rules = base_model_schema.parameter_rules entity = AIModelEntity( model=model, diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py index 508da4bf20..407dc7190e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/rerank/rerank.py @@ -64,7 +64,7 @@ class OAICompatRerankModel(RerankModel): # TODO: Do we need truncate docs to avoid llama.cpp return error? - data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n} + data = {"model": model_name, "query": query, "documents": docs, "top_n": top_n, "return_documents": True} try: response = post(str(URL(url) / "rerank"), headers=headers, data=dumps(data), timeout=60) @@ -83,7 +83,13 @@ class OAICompatRerankModel(RerankModel): index = result["index"] # Retrieve document text (fallback if llama.cpp rerank doesn't return it) - text = result.get("document", {}).get("text", docs[index]) + text = docs[index] + document = result.get("document", {}) + if document: + if isinstance(document, dict): + text = document.get("text", docs[index]) + elif isinstance(document, str): + text = document # Normalize the score normalized_score = (result["relevance_score"] - min_score) / score_range diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index c2b7297aac..793c384d5a 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -176,7 +176,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py index 
351dcced15..2789a9250a 100644 --- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py +++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py @@ -37,13 +37,14 @@ class OpenLLMGenerateMessage: class OpenLLMGenerate: def generate( self, + *, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any], - stop: list[str], + stop: list[str] | None = None, prompt_messages: list[OpenLLMGenerateMessage], - user: str, + user: str | None = None, ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]: if not server_url: raise InvalidAuthenticationError("Invalid server URL") diff --git a/api/core/model_runtime/model_providers/openrouter/llm/llm.py b/api/core/model_runtime/model_providers/openrouter/llm/llm.py index 736ab8e7a8..2d6ece8113 100644 --- a/api/core/model_runtime/model_providers/openrouter/llm/llm.py +++ b/api/core/model_runtime/model_providers/openrouter/llm/llm.py @@ -45,19 +45,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): user: Optional[str] = None, ) -> Union[LLMResult, Generator]: self._update_credential(model, credentials) - - block_as_stream = False - if model.startswith("openai/o1"): - block_as_stream = True - stop = None - - # invoke block as stream - if stream and block_as_stream: - return self._generate_block_as_stream( - model, credentials, prompt_messages, model_parameters, tools, stop, user - ) - else: - return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def _generate_block_as_stream( self, @@ -69,9 +57,7 @@ class OpenRouterLargeLanguageModel(OAIAPICompatLargeLanguageModel): stop: Optional[list[str]] = None, user: Optional[str] = None, ) -> Generator: - resp: LLMResult = super()._generate( - model, credentials, prompt_messages, model_parameters, tools, stop, False, user - ) + resp = super()._generate(model, credentials, prompt_messages, model_parameters, tools, stop, False, user) yield LLMResultChunk( model=model, diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index d78bdaa75e..7bbd31e87c 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -182,7 +182,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml index f010e4c826..b52df3e4e3 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml @@ -24,4 +24,3 @@ - meta-llama/Meta-Llama-3.1-8B-Instruct - google/gemma-2-27b-it - google/gemma-2-9b-it -- deepseek-ai/DeepSeek-V2-Chat diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml 
b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml
index 71f9a92381..73a9e80769 100644
--- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml
+++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml
@@ -18,6 +18,7 @@ supported_model_types:
   - text-embedding
   - rerank
   - speech2text
+  - tts
 configurate_methods:
   - predefined-model
   - customizable-model
diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/__init__.py b/api/core/model_runtime/model_providers/siliconflow/tts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.4.yaml b/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.4.yaml
new file mode 100644
index 0000000000..4adfd05c60
--- /dev/null
+++ b/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.4.yaml
@@ -0,0 +1,37 @@
+model: fishaudio/fish-speech-1.4
+model_type: tts
+model_properties:
+  default_voice: 'fishaudio/fish-speech-1.4:alex'
+  voices:
+    - mode: "fishaudio/fish-speech-1.4:alex"
+      name: "Alex(男声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:benjamin"
+      name: "Benjamin(男声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:charles"
+      name: "Charles(男声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:david"
+      name: "David(男声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:anna"
+      name: "Anna(女声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:bella"
+      name: "Bella(女声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:claire"
+      name: "Claire(女声)"
+      language: [ "zh-Hans", "en-US" ]
+    - mode: "fishaudio/fish-speech-1.4:diana"
+      name: "Diana(女声)"
+      language: [ "zh-Hans", "en-US" ]
+  audio_type: 'mp3'
+  max_workers: 5
+  # stream: false
+pricing:
+  input: '0.015'
+  output: '0'
+  unit: '0.001'
+  currency: RMB
diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/tts.py b/api/core/model_runtime/model_providers/siliconflow/tts/tts.py
new file mode 100644
index 0000000000..a5554abb73
--- /dev/null
+++ b/api/core/model_runtime/model_providers/siliconflow/tts/tts.py
@@ -0,0 +1,105 @@
+import concurrent.futures
+from typing import Any, Optional
+
+from openai import OpenAI
+
+from core.model_runtime.errors.invoke import InvokeBadRequestError
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.tts_model import TTSModel
+from core.model_runtime.model_providers.openai._common import _CommonOpenAI
+
+
+class SiliconFlowText2SpeechModel(_CommonOpenAI, TTSModel):
+    """
+    Model class for SiliconFlow text-to-speech model.
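+
+    Synthesis goes through the OpenAI-compatible audio API at
+    https://api.siliconflow.cn; inputs longer than 4096 characters are split
+    into sentences and rendered concurrently before being streamed back as mp3.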
+ """ + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> Any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param user: unique user id + :return: text translated to audio file + """ + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: + voice = self._get_model_default_voice(model, credentials) + # if streaming: + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :param user: unique user id + :return: text translated to audio file + """ + try: + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text="Hello SiliconFlow!", + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: + """ + _tts_invoke_streaming text2speech model + + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + # doc: https://docs.siliconflow.cn/capabilities/text-to-speech + self._add_custom_parameters(credentials) + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] + if not voice or voice not in model_support_voice: + voice = self._get_model_default_voice(model, credentials) + if len(content_text) > 4096: + sentences = self._split_text_into_sentences(content_text, max_length=4096) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] + for future in futures: + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 + + else: + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) + + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials["openai_api_base"] = "https://api.siliconflow.cn" + credentials["openai_api_key"] = credentials["api_key"] diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml index 50e10226a5..94b6666d05 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max-0809.yaml @@ -6,6 +6,7 @@ model_type: llm features: - vision - agent-thought + - video model_properties: mode: 
chat context_size: 32000 diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml index 21b127f56c..b6172c1cbc 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-max.yaml @@ -6,6 +6,7 @@ model_type: llm features: - vision - agent-thought + - video model_properties: mode: chat context_size: 32000 diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml index 67b2b2ebdd..0be4b68f4f 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus-0809.yaml @@ -6,6 +6,7 @@ model_type: llm features: - vision - agent-thought + - video model_properties: mode: chat context_size: 32768 diff --git a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml index f55764c6c0..6c8a8121c6 100644 --- a/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml +++ b/api/core/model_runtime/model_providers/tongyi/llm/qwen-vl-plus.yaml @@ -6,6 +6,7 @@ model_type: llm features: - vision - agent-thought + - video model_properties: mode: chat context_size: 8000 diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 43233e6126..9cd0c78d99 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -173,7 +173,7 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py index ce4f0c3ab1..4a6f5b6f7b 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/text_embedding/models.py @@ -22,7 +22,7 @@ def get_model_config(credentials: dict) -> ModelConfig: return ModelConfig( properties=ModelProperties( context_size=int(credentials.get("context_size", 0)), - max_chunks=int(credentials.get("max_chunks", 0)), + max_chunks=int(credentials.get("max_chunks", 1)), ) ) return model_configs diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py index e69c9fccba..16f1bd43d8 100644 --- a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -166,7 +166,7 @@ class VoyageTextEmbeddingModel(TextEmbeddingModel): label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + 
model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))},
         )
 
         return entity
diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
index 7c305735b9..bb71de2bad 100644
--- a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
+++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml
@@ -1,9 +1,12 @@
 model: grok-beta
 label:
-  en_US: Grok beta
+  en_US: Grok Beta
 model_type: llm
 features:
+  - agent-thought
+  - tool-call
   - multi-tool-call
+  - stream-tool-call
 model_properties:
   mode: chat
   context_size: 131072
diff --git a/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml
new file mode 100644
index 0000000000..844f0520bc
--- /dev/null
+++ b/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml
@@ -0,0 +1,64 @@
+model: grok-vision-beta
+label:
+  en_US: Grok Vision Beta
+model_type: llm
+features:
+  - agent-thought
+  - vision
+model_properties:
+  mode: chat
+  context_size: 8192
+parameter_rules:
+  - name: temperature
+    label:
+      en_US: "Temperature"
+      zh_Hans: "采样温度"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 2.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The sampling temperature controls the randomness of the output. The temperature value is within the range of [0.0, 2.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 2.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: top_p
+    label:
+      en_US: "Top P"
+      zh_Hans: "Top P"
+    type: float
+    default: 0.7
+    min: 0.0
+    max: 1.0
+    precision: 1
+    required: true
+    help:
+      en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time."
+      zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。"
+
+  - name: frequency_penalty
+    use_template: frequency_penalty
+    label:
+      en_US: "Frequency Penalty"
+      zh_Hans: "频率惩罚"
+    type: float
+    default: 0
+    min: 0
+    max: 2.0
+    precision: 1
+    required: false
+    help:
+      en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim."
+      zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。"
+
+  - name: user
+    use_template: text
+    label:
+      en_US: "User"
+      zh_Hans: "用户"
+    type: string
+    required: false
+    help:
+      en_US: "Used to track and differentiate conversation requests from different users."
+ zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py index 3f5325a857..eacd086fee 100644 --- a/api/core/model_runtime/model_providers/x/llm/llm.py +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -35,3 +35,5 @@ class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" credentials["mode"] = LLMMode.CHAT.value credentials["function_calling_type"] = "tool_call" + credentials["stream_function_calling"] = "support" + credentials["vision_support"] = "support" diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index b82f0430c5..8d86d6937d 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -63,6 +63,9 @@ from core.model_runtime.model_providers.xinference.xinference_helper import ( ) from core.model_runtime.utils import helper +DEFAULT_MAX_RETRIES = 3 +DEFAULT_INVOKE_TIMEOUT = 60 + class XinferenceAILargeLanguageModel(LargeLanguageModel): def _invoke( @@ -315,7 +318,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + message_dict = { + "tool_call_id": message.tool_call_id, + "role": "tool", + "content": message.content, + "name": message.name, + } else: raise ValueError(f"Unknown message type {type(message)}") @@ -466,8 +474,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): client = OpenAI( base_url=f'{credentials["server_url"]}/v1', api_key=api_key, - max_retries=3, - timeout=60, + max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES), + timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT), ) xinference_client = Client( diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml index be9073c1ca..3500136693 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference.yaml +++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml @@ -56,3 +56,23 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的API密钥 en_US: Enter the api key + - variable: invoke_timeout + label: + zh_Hans: 调用超时时间 (单位:秒) + en_US: invoke timeout (unit:second) + type: text-input + required: true + default: '60' + placeholder: + zh_Hans: 在此输入调用超时时间 + en_US: Enter invoke timeout value + - variable: max_retries + label: + zh_Hans: 调用重试次数 + en_US: max retries + type: text-input + required: true + default: '3' + placeholder: + zh_Hans: 在此输入调用重试次数 + en_US: Enter max retries diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml index 91550ceee8..dbda18b888 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -6,6 +6,7 @@ model_properties: mode: chat features: - vision + - video parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py 
b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index f629b62fd5..2428284ba9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -105,17 +105,6 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def embed_query(self, text: str) -> list[float]: - """Call out to ZhipuAI's embedding endpoint. - - Args: - text: The text to embed. - - Returns: - Embeddings for the text. - """ - return self.embed_documents([text])[0] - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/moderation/keywords/keywords.py b/api/core/moderation/keywords/keywords.py index 4846da8f93..00b3c56c03 100644 --- a/api/core/moderation/keywords/keywords.py +++ b/api/core/moderation/keywords/keywords.py @@ -1,3 +1,6 @@ +from collections.abc import Sequence +from typing import Any + from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult @@ -62,5 +65,5 @@ class KeywordsModeration(Moderation): def _is_violated(self, inputs: dict, keywords_list: list) -> bool: return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values()) - def _check_keywords_in_value(self, keywords_list, value) -> bool: - return any(keyword.lower() in value.lower() for keyword in keywords_list) + def _check_keywords_in_value(self, keywords_list: Sequence[str], value: Any) -> bool: + return any(keyword.lower() in str(value).lower() for keyword in keywords_list) diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 256595286f..71ff03b6ef 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator @@ -122,7 +122,7 @@ trace_info_info_map = { } -class TraceTaskName(str, Enum): +class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index 447b799f1f..f486da3a6d 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -39,7 +39,7 @@ def validate_input_output(v, field_name): return v -class LevelEnum(str, Enum): +class LevelEnum(StrEnum): DEBUG = "DEBUG" WARNING = "WARNING" ERROR = "ERROR" @@ -178,7 +178,7 @@ class LangfuseSpan(BaseModel): return validate_input_output(v, field_name) -class UnitEnum(str, Enum): +class UnitEnum(StrEnum): CHARACTERS = "CHARACTERS" TOKENS = "TOKENS" SECONDS = "SECONDS" diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 16c76f363c..99221d669b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ 
b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -8,7 +8,7 @@ from pydantic_core.core_schema import ValidationInfo from core.ops.utils import replace_text_with_content -class LangSmithRunType(str, Enum): +class LangSmithRunType(StrEnum): tool = "tool" chain = "chain" llm = "llm" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 1069889abd..b7799ce1fb 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -445,7 +445,7 @@ class TraceTask: "ls_provider": message_data.model_provider, "ls_model_name": message_data.model_id, "status": message_data.status, - "from_end_user_id": message_data.from_account_id, + "from_end_user_id": message_data.from_end_user_id, "from_account_id": message_data.from_account_id, "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, @@ -521,7 +521,7 @@ class TraceTask: "ls_provider": message_data.model_provider, "ls_model_name": message_data.model_id, "status": message_data.status, - "from_end_user_id": message_data.from_account_id, + "from_end_user_id": message_data.from_end_user_id, "from_account_id": message_data.from_account_id, "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, @@ -570,7 +570,7 @@ class TraceTask: "ls_provider": message_data.model_provider, "ls_model_name": message_data.model_id, "status": message_data.status, - "from_end_user_id": message_data.from_account_id, + "from_end_user_id": message_data.from_end_user_id, "from_account_id": message_data.from_account_id, "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 5a3481b963..93dd92f188 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(str, enum.Enum): +class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 5eec5e3c99..aa175153bc 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,3 +1,4 @@ +from collections.abc import Sequence from typing import cast from core.model_runtime.entities import ( @@ -14,7 +15,7 @@ from core.prompt.simple_prompt_transform import ModelMode class PromptMessageUtil: @staticmethod - def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: list[PromptMessage]) -> list[dict]: + def prompt_messages_to_prompt_for_saving(model_mode: str, prompt_messages: Sequence[PromptMessage]) -> list[dict]: """ Prompt messages to prompt for saving. 
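The `Enum` to `StrEnum` conversions running through these entity files rely on Python 3.11+, in line with the drop of 3.10 support elsewhere in this PR. The two spellings compare equal to plain strings, but they differ in `str()`, which matters wherever members end up in f-strings or serialized payloads. A minimal sketch of the difference (class names here are illustrative, not from the codebase):

```python
from enum import Enum, StrEnum  # StrEnum is available from Python 3.11


class OldMode(str, Enum):  # pre-3.11 idiom: mix str into Enum
    CHAT = "chat"


class NewMode(StrEnum):  # the replacement used throughout this PR
    CHAT = "chat"


# Both compare equal to the raw string:
assert OldMode.CHAT == "chat" and NewMode.CHAT == "chat"

# But str() differs: the mixin form keeps the qualified member name, while
# StrEnum yields the plain value, which is safer for logging and JSON output.
assert str(OldMode.CHAT) == "OldMode.CHAT"
assert str(NewMode.CHAT) == "chat"
```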
:param model_mode: model mode diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index 3c6ab2e4cf..754b0d18b7 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -12,7 +12,7 @@ class CleanProcessor: # Unicode U+FFFE text = re.sub("\ufffe", "", text) - rules = process_rule["rules"] if process_rule else None + rules = process_rule["rules"] if process_rule else {} if "pre_processing_rules" in rules: pre_processing_rules = rules["pre_processing_rules"] for pre_processing_rule in pre_processing_rules: diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py index d6deba3fb0..d845c7111d 100644 --- a/api/core/rag/datasource/keyword/keyword_type.py +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -1,5 +1,5 @@ -from enum import Enum +from enum import StrEnum -class KeyWordType(str, Enum): +class KeyWordType(StrEnum): JIEBA = "jieba" diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 8e53e3ae84..05183c0371 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class VectorType(str, Enum): +class VectorType(StrEnum): ANALYTICDB = "analyticdb" CHROMA = "chroma" MILVUS = "milvus" diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 313bdce48b..b23da1113e 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -114,10 +114,10 @@ class WordExtractor(BaseExtractor): mime_type=mime_type or "", created_by=self.user_id, created_by_role=CreatedByRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) db.session.add(upload_file) diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fc82b2080b..6ae432a526 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -27,11 +27,11 @@ class RerankModelRunner(BaseRerankRunner): :return: """ docs = [] - doc_id = set() + doc_ids = set() unique_documents = [] for document in documents: - if document.provider == "dify" and document.metadata["doc_id"] not in doc_id: - doc_id.add(document.metadata["doc_id"]) + if document.provider == "dify" and document.metadata["doc_id"] not in doc_ids: + doc_ids.add(document.metadata["doc_id"]) docs.append(document.page_content) unique_documents.append(document) elif document.provider == "external": diff --git a/api/core/rag/rerank/rerank_type.py b/api/core/rag/rerank/rerank_type.py index d71eb2daa8..b2d1314654 100644 --- a/api/core/rag/rerank/rerank_type.py +++ b/api/core/rag/rerank/rerank_type.py @@ -1,6 +1,6 @@ -from enum import Enum +from enum import StrEnum -class RerankMode(str, Enum): +class RerankMode(StrEnum): RERANKING_MODEL = "reranking_model" WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index b706f29bb1..4719be012f 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -37,11 +37,10 @@ class WeightRerankRunner(BaseRerankRunner): 
:return:
         """
         unique_documents = []
-        doc_id = set()
+        doc_ids = set()
         for document in documents:
-            doc_id = document.metadata.get("doc_id")
-            if doc_id not in doc_id:
-                doc_id.add(doc_id)
+            if document.metadata["doc_id"] not in doc_ids:
+                doc_ids.add(document.metadata["doc_id"])
                 unique_documents.append(document)
 
         documents = unique_documents
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index d8637fd2cb..4fc383f91b 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -1,4 +1,4 @@
-from enum import Enum
+from enum import Enum, StrEnum
 from typing import Any, Optional, Union, cast
 
 from pydantic import BaseModel, Field, field_validator
@@ -137,7 +137,7 @@ class ToolParameterOption(BaseModel):
 
 
 class ToolParameter(BaseModel):
-    class ToolParameterType(str, Enum):
+    class ToolParameterType(StrEnum):
         STRING = "string"
         NUMBER = "number"
         BOOLEAN = "boolean"
diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py
index 8bed2c556c..1afe2f8385 100644
--- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py
+++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py
@@ -66,6 +66,41 @@ class BingSearchTool(BuiltinTool):
                 results.append(self.create_text_message(text=f'{related.get("displayText", "")}{url}'))
 
             return results
+        elif result_type == "json":
+            result = {}
+            if search_results:
+                result["organic"] = [
+                    {
+                        "title": item.get("name", ""),
+                        "snippet": item.get("snippet", ""),
+                        "url": item.get("url", ""),
+                        "siteName": item.get("siteName", ""),
+                    }
+                    for item in search_results
+                ]
+
+            if computation and "expression" in computation and "value" in computation:
+                result["computation"] = {"expression": computation["expression"], "value": computation["value"]}
+
+            if entities:
+                result["entities"] = [
+                    {
+                        "name": item.get("name", ""),
+                        "url": item.get("url", ""),
+                        "description": item.get("description", ""),
+                    }
+                    for item in entities
+                ]
+
+            if news:
+                result["news"] = [{"name": item.get("name", ""), "url": item.get("url", "")} for item in news]
+
+            if related_searches:
+                result["related searches"] = [
+                    {"displayText": item.get("displayText", ""), "url": item.get("webSearchUrl", "")}
+                    for item in related_searches
+                ]
+
+            return self.create_json_message(result)
         else:
             # construct text
             text = ""
diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml b/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml
index a3f60bb09b..f5c932c37b 100644
--- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml
+++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.yaml
@@ -113,9 +113,9 @@ parameters:
       zh_Hans: 结果类型
       pt_BR: result type
     human_description:
-      en_US: return a list of links or texts
-      zh_Hans: 返回一个连接列表还是纯文本内容
-      pt_BR: return a list of links or texts
+      en_US: return a list of links, json or texts
+      zh_Hans: 返回一个列表,内容是链接、json还是纯文本
+      pt_BR: return a list of links, json or texts
     default: text
     options:
      - value: link
@@ -123,6 +123,11 @@
        label:
          en_US: Link
          zh_Hans: 链接
          pt_BR: Link
+     - value: json
+       label:
+         en_US: JSON
+         zh_Hans: JSON
+         pt_BR: JSON
      - value: text
        label:
          en_US: Text
diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py
index dfa3fbea6a..8fa647d9ed 100644
--- a/api/core/tools/provider/builtin/chart/chart.py
+++ b/api/core/tools/provider/builtin/chart/chart.py
@@ -1,3 +1,4 @@
+import matplotlib
 import 
matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties, fontManager @@ -5,7 +6,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl def set_chinese_font(): - font_list = [ + to_find_fonts = [ "PingFang SC", "SimHei", "Microsoft YaHei", @@ -15,16 +16,16 @@ def set_chinese_font(): "Noto Sans CJK SC", "Noto Sans CJK JP", ] - - for font in font_list: - if font in fontManager.ttflist: - chinese_font = FontProperties(font) - if chinese_font.get_name() == font: - return chinese_font + installed_fonts = frozenset(fontInfo.name for fontInfo in fontManager.ttflist) + for font in to_find_fonts: + if font in installed_fonts: + return FontProperties(font) return FontProperties() +# use non-interactive backend to prevent `RuntimeError: main thread is not in main loop` +matplotlib.use("Agg") # use a business theme plt.style.use("seaborn-v0_8-darkgrid") plt.rcParams["axes.unicode_minus"] = False diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index 54bb38755a..b3c630878f 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -18,6 +18,12 @@ class DuckDuckGoImageSearchTool(BuiltinTool): "size": tool_parameters.get("size"), "max_results": tool_parameters.get("max_results"), } + + # Add query_prefix handling + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query_dict['keywords']}".strip() + query_dict["keywords"] = final_query + response = DDGS().images(**query_dict) markdown_result = "\n\n" json_result = [] diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml index 168cface22..a543d1e218 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml @@ -86,3 +86,14 @@ parameters: en_US: The size of the image to be searched. zh_Hans: 要搜索的图片的大小 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:unsplash.com" + zh_Hans: 定向搜索 e.g. 
"site:unsplash.com" diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py index 3a6fd394a8..11da6f5cf7 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py @@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool SUMMARY_PROMPT = """ -User's query: +User's query: {query} Here are the news results: @@ -30,6 +30,12 @@ class DuckDuckGoNewsSearchTool(BuiltinTool): "safesearch": "moderate", "region": "wt-wt", } + + # Add query_prefix handling + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query_dict['keywords']}".strip() + query_dict["keywords"] = final_query + try: response = list(DDGS().news(**query_dict)) if not response: diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml index eb2b67b7c9..6e181e0f41 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml @@ -69,3 +69,14 @@ parameters: en_US: Whether to pass the news results to llm for summarization. zh_Hans: 是否需要将新闻结果传给大模型总结 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:msn.com" + zh_Hans: 定向搜索 e.g. "site:msn.com" diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index cbd65d2e77..3cd35d16a6 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool SUMMARY_PROMPT = """ -User's query: +User's query: {query} Here is the search engine result: @@ -26,7 +26,12 @@ class DuckDuckGoSearchTool(BuiltinTool): query = tool_parameters.get("query") max_results = tool_parameters.get("max_results", 5) require_summary = tool_parameters.get("require_summary", False) - response = DDGS().text(query, max_results=max_results) + + # Add query_prefix handling + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query}".strip() + + response = DDGS().text(final_query, max_results=max_results) if require_summary: results = "\n".join([res.get("body") for res in response]) results = self.summary_results(user_id=user_id, content=results, query=query) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml index 333c0cb093..54e27d9905 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml @@ -39,3 +39,14 @@ parameters: en_US: Whether to pass the search results to llm for summarization. zh_Hans: 是否需要将搜索结果传给大模型总结 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:wikipedia.org" + zh_Hans: 定向搜索 e.g. 
"site:wikipedia.org" diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py index 4b74b223c1..1eef0b1ba2 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py @@ -24,7 +24,7 @@ max-width: 100%; border-radius: 8px;"> def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: query_dict = { - "keywords": tool_parameters.get("query"), + "keywords": tool_parameters.get("query"), # LLM's query "region": tool_parameters.get("region", "wt-wt"), "safesearch": tool_parameters.get("safesearch", "moderate"), "timelimit": tool_parameters.get("timelimit"), @@ -40,6 +40,12 @@ max-width: 100%; border-radius: 8px;"> # Get proxy URL from parameters proxy_url = tool_parameters.get("proxy_url", "").strip() + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query_dict['keywords']}".strip() + + # Update the keywords in query_dict with the final_query + query_dict["keywords"] = final_query + response = DDGS().videos(**query_dict) # Create HTML result with embedded iframes @@ -51,9 +57,13 @@ max-width: 100%; border-radius: 8px;"> embed_html = res.get("embed_html", "") description = res.get("description", "") content_url = res.get("content", "") + transcript_url = None # Handle TED.com videos - if not embed_html and "ted.com/talks" in content_url: + if "ted.com/talks" in content_url: + # Create transcript URL + transcript_url = f"{content_url}/transcript" + # Create embed URL embed_url = content_url.replace("www.ted.com", "embed.ted.com") if proxy_url: embed_url = f"{proxy_url}{embed_url}" @@ -68,8 +78,14 @@ max-width: 100%; border-radius: 8px;"> markdown_result += f"{title}\n\n" markdown_result += f"{embed_html}\n\n" + if description: + markdown_result += f"{description}\n\n" markdown_result += "---\n\n" - json_result.append(self.create_json_message(res)) + # Add transcript_url to the JSON result if available + result_dict = res.copy() + if transcript_url: + result_dict["transcript_url"] = transcript_url + json_result.append(self.create_json_message(result_dict)) return [self.create_text_message(markdown_result)] + json_result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml index a516d3cb98..d846244e3d 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml @@ -95,3 +95,14 @@ parameters: en_US: Proxy URL zh_Hans: 视频代理地址 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:www.ted.com" + zh_Hans: 定向搜索 e.g. 
"site:www.ted.com" diff --git a/api/core/tools/provider/builtin/email/tools/send_mail.py b/api/core/tools/provider/builtin/email/tools/send_mail.py index d51d5439b7..33c040400c 100644 --- a/api/core/tools/provider/builtin/email/tools/send_mail.py +++ b/api/core/tools/provider/builtin/email/tools/send_mail.py @@ -17,7 +17,7 @@ class SendMailTool(BuiltinTool): invoke tools """ sender = self.runtime.credentials.get("email_account", "") - email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") + email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") password = self.runtime.credentials.get("email_password", "") smtp_server = self.runtime.credentials.get("smtp_server", "") if not smtp_server: diff --git a/api/core/tools/provider/builtin/email/tools/send_mail_batch.py b/api/core/tools/provider/builtin/email/tools/send_mail_batch.py index ff7e176990..537dedb27d 100644 --- a/api/core/tools/provider/builtin/email/tools/send_mail_batch.py +++ b/api/core/tools/provider/builtin/email/tools/send_mail_batch.py @@ -18,7 +18,7 @@ class SendMailTool(BuiltinTool): invoke tools """ sender = self.runtime.credentials.get("email_account", "") - email_rgx = re.compile(r"^[a-zA-Z0-9_-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") + email_rgx = re.compile(r"^[a-zA-Z0-9._-]+@[a-zA-Z0-9_-]+(\.[a-zA-Z0-9_-]+)+$") password = self.runtime.credentials.get("email_password", "") smtp_server = self.runtime.credentials.get("smtp_server", "") if not smtp_server: diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/embedding.py b/api/core/tools/provider/builtin/gitee_ai/tools/embedding.py new file mode 100644 index 0000000000..ab03759c19 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/tools/embedding.py @@ -0,0 +1,25 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class GiteeAIToolEmbedding(BuiltinTool): + def _invoke( + self, user_id: str, tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + headers = { + "content-type": "application/json", + "authorization": f"Bearer {self.runtime.credentials['api_key']}", + } + + payload = {"inputs": tool_parameters.get("inputs")} + model = tool_parameters.get("model", "bge-m3") + url = f"https://ai.gitee.com/api/serverless/{model}/embeddings" + response = requests.post(url, json=payload, headers=headers) + if response.status_code != 200: + return self.create_text_message(f"Got Error Response:{response.text}") + + return [self.create_text_message(response.content.decode("utf-8"))] diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/embedding.yaml b/api/core/tools/provider/builtin/gitee_ai/tools/embedding.yaml new file mode 100644 index 0000000000..53e569d731 --- /dev/null +++ b/api/core/tools/provider/builtin/gitee_ai/tools/embedding.yaml @@ -0,0 +1,37 @@ +identity: + name: embedding + author: gitee_ai + label: + en_US: embedding + icon: icon.svg +description: + human: + en_US: Generate word embeddings using Serverless-supported models (compatible with OpenAI) + llm: This tool is used to generate word embeddings from text input. 
+parameters: + - name: model + type: string + required: true + in: path + description: + en_US: Supported Embedding (compatible with OpenAI) interface models + enum: + - bge-m3 + - bge-large-zh-v1.5 + - bge-small-zh-v1.5 + label: + en_US: Service Model + zh_Hans: 服务模型 + default: bge-m3 + form: form + - name: inputs + type: string + required: true + label: + en_US: Input Text + zh_Hans: 输入文本 + human_description: + en_US: The text input used to generate embeddings. + zh_Hans: 用于生成词向量的输入文本。 + llm_description: This text input will be used to generate embeddings. + form: llm diff --git a/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py index 14291d1729..bb0b2c915b 100644 --- a/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py +++ b/api/core/tools/provider/builtin/gitee_ai/tools/text-to-image.py @@ -6,7 +6,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool -class GiteeAITool(BuiltinTool): +class GiteeAIToolText2Image(BuiltinTool): def _invoke( self, user_id: str, tool_parameters: dict[str, Any] ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py index 1e77f3c6df..ebcf13dc99 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_files.py @@ -69,14 +69,16 @@ class GitlabFilesTool(BuiltinTool): self.fetch_files(site_url, access_token, identifier, branch, item_path, is_repository) ) else: # It's a file + encoded_item_path = urllib.parse.quote(item_path, safe="") if is_repository: file_url = ( f"{domain}/api/v4/projects/{encoded_identifier}/repository/files" - f"/{item_path}/raw?ref={branch}" + f"/{encoded_item_path}/raw?ref={branch}" ) else: file_url = ( - f"{domain}/api/v4/projects/{project_id}/repository/files/{item_path}/raw?ref={branch}" + f"{domain}/api/v4/projects/{project_id}/repository/files" + f"/{encoded_item_path}/raw?ref={branch}" ) file_response = requests.get(file_url, headers=headers) diff --git a/api/core/tools/provider/builtin/json_process/tools/parse.py b/api/core/tools/provider/builtin/json_process/tools/parse.py index 37cae40153..f91432ee77 100644 --- a/api/core/tools/provider/builtin/json_process/tools/parse.py +++ b/api/core/tools/provider/builtin/json_process/tools/parse.py @@ -40,6 +40,9 @@ class JSONParseTool(BuiltinTool): expr = parse(json_filter) result = [match.value for match in expr.find(input_data)] + if not result: + return "" + if len(result) == 1: result = result[0] diff --git a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py index 762e158459..db4adfd4ad 100644 --- a/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py +++ b/api/core/tools/provider/builtin/novitaai/_novita_tool_base.py @@ -12,7 +12,7 @@ class NovitaAiToolBase: if not loras_str: return [] - loras_ori_list = lora_str.strip().split(";") + loras_ori_list = loras_str.strip().split(";") result_list = [] for lora_str in loras_ori_list: lora_info = lora_str.strip().split(",") diff --git a/api/core/tools/provider/builtin/rapidapi/_assets/rapidapi.png b/api/core/tools/provider/builtin/rapidapi/_assets/rapidapi.png new file mode 100644 index 0000000000..9c7468bb17 Binary files /dev/null and 
b/api/core/tools/provider/builtin/rapidapi/_assets/rapidapi.png differ diff --git a/api/core/tools/provider/builtin/rapidapi/rapidapi.py b/api/core/tools/provider/builtin/rapidapi/rapidapi.py new file mode 100644 index 0000000000..31077b0894 --- /dev/null +++ b/api/core/tools/provider/builtin/rapidapi/rapidapi.py @@ -0,0 +1,22 @@ +from typing import Any + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.rapidapi.tools.google_news import GooglenewsTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class RapidapiProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + try: + GooglenewsTool().fork_tool_runtime( + meta={ + "credentials": credentials, + } + ).invoke( + user_id="", + tool_parameters={ + "language_region": "en-US", + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/rapidapi/rapidapi.yaml b/api/core/tools/provider/builtin/rapidapi/rapidapi.yaml new file mode 100644 index 0000000000..3f1d1c5824 --- /dev/null +++ b/api/core/tools/provider/builtin/rapidapi/rapidapi.yaml @@ -0,0 +1,39 @@ +identity: + name: rapidapi + author: Steven Sun + label: + en_US: RapidAPI + zh_Hans: RapidAPI + description: + en_US: RapidAPI is the world's largest API marketplace with over 1,000,000 developers and 10,000 APIs. + zh_Hans: RapidAPI是全球最大的API市场,拥有超过100万开发人员和10000个API。 + icon: rapidapi.png + tags: + - news +credentials_for_provider: + x-rapidapi-host: + type: text-input + required: true + label: + en_US: x-rapidapi-host + zh_Hans: x-rapidapi-host + placeholder: + en_US: Please input your x-rapidapi-host + zh_Hans: 请输入你的 x-rapidapi-host + help: + en_US: Get your x-rapidapi-host from RapidAPI. + zh_Hans: 从 RapidAPI 获取您的 x-rapidapi-host。 + url: https://rapidapi.com/ + x-rapidapi-key: + type: secret-input + required: true + label: + en_US: x-rapidapi-key + zh_Hans: x-rapidapi-key + placeholder: + en_US: Please input your x-rapidapi-key + zh_Hans: 请输入你的 x-rapidapi-key + help: + en_US: Get your x-rapidapi-key from RapidAPI. 
+      zh_Hans: 从 RapidAPI 获取您的 x-rapidapi-key。 + url: https://rapidapi.com/ diff --git a/api/core/tools/provider/builtin/rapidapi/tools/google_news.py b/api/core/tools/provider/builtin/rapidapi/tools/google_news.py new file mode 100644 index 0000000000..d4b6dc4a46 --- /dev/null +++ b/api/core/tools/provider/builtin/rapidapi/tools/google_news.py @@ -0,0 +1,33 @@ +from typing import Any, Union + +import requests + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolInvokeError, ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class GooglenewsTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + key = self.runtime.credentials.get("x-rapidapi-key", "") + host = self.runtime.credentials.get("x-rapidapi-host", "") + if not all([key, host]): + raise ToolProviderCredentialValidationError("Please input correct x-rapidapi-key and x-rapidapi-host") + headers = {"x-rapidapi-key": key, "x-rapidapi-host": host} + lr = tool_parameters.get("language_region", "") + url = f"https://{host}/latest?lr={lr}" + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolInvokeError(f"Error {response.status_code}: {response.text}") + return self.create_text_message(response.text) + + def validate_credentials(self, parameters: dict[str, Any]) -> None: + parameters["validate"] = True + self._invoke(user_id="", tool_parameters=parameters) diff --git a/api/core/tools/provider/builtin/rapidapi/tools/google_news.yaml b/api/core/tools/provider/builtin/rapidapi/tools/google_news.yaml new file mode 100644 index 0000000000..547681b166 --- /dev/null +++ b/api/core/tools/provider/builtin/rapidapi/tools/google_news.yaml @@ -0,0 +1,24 @@ +identity: + name: google_news + author: Steven Sun + label: + en_US: GoogleNews + zh_Hans: 谷歌新闻 +description: + human: + en_US: google news is a news aggregator service developed by Google. It presents a continuous, customizable flow of articles organized from thousands of publishers and magazines. + zh_Hans: 谷歌新闻是由谷歌开发的新闻聚合服务。它提供了一个持续的、可定制的文章流,这些文章是从成千上万的出版商和杂志中整理出来的。 + llm: A tool to get the latest news from Google News. +parameters: + - name: language_region + type: string + required: true + label: + en_US: Language and Region + zh_Hans: 语言和地区 + human_description: + en_US: The language and region determine the language and region of the search results, and its value is assigned according to the "National Language Code Comparison Table", such as en-US, which stands for English (United States); zh-CN, stands for Chinese (Simplified). + zh_Hans: 语言和地区决定了搜索结果的语言和地区,其赋值按照《国家语言代码对照表》,形如en-US,代表英语(美国);zh-CN,代表中文(简体)。 + llm_description: The language and region determine the language and region of the search results, and its value is assigned according to the "National Language Code Comparison Table", such as en-US, which stands for English (United States); zh-CN, stands for Chinese (Simplified). 
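The new tool is a thin wrapper over one GET request. A minimal sketch outside the tool runtime, assuming placeholder credentials (the host name below is hypothetical; real values come from your RapidAPI subscription, and the `/latest?lr=` path mirrors the tool code):

```python
import requests

# Placeholder credentials; real values come from RapidAPI.
host = "example-google-news.p.rapidapi.com"  # hypothetical host
key = "YOUR_RAPIDAPI_KEY"

resp = requests.get(
    f"https://{host}/latest?lr=en-US",
    headers={"x-rapidapi-key": key, "x-rapidapi-host": host},
)
resp.raise_for_status()
print(resp.text)  # raw JSON news feed, returned to the LLM as text
```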
+ default: en-US + form: llm diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index 17e2978194..29d36f5f23 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if type == "text": diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index c478bc108b..de42360898 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if type == "text": diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 562bc01964..c8b3ccda05 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if type == "text": diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 1867cf7be7..b14821f831 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if "transcripts" in res and "text" in res["transcripts"][0]: diff --git a/api/core/tools/provider/builtin/slidespeak/_assets/icon.png b/api/core/tools/provider/builtin/slidespeak/_assets/icon.png new file mode 100644 index 0000000000..4cac578330 Binary files /dev/null and b/api/core/tools/provider/builtin/slidespeak/_assets/icon.png differ diff --git a/api/core/tools/provider/builtin/slidespeak/slidespeak.py b/api/core/tools/provider/builtin/slidespeak/slidespeak.py new file mode 100644 index 0000000000..14c7c4880e --- /dev/null +++ b/api/core/tools/provider/builtin/slidespeak/slidespeak.py @@ -0,0 +1,28 @@ +from typing import Any + +import requests +from yarl import URL + +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SlideSpeakProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict[str, Any]) -> None: + api_key = credentials.get("slidespeak_api_key") + base_url = credentials.get("base_url") + + if not api_key: + raise ToolProviderCredentialValidationError("API key is missing") + + if base_url: + base_url = str(URL(base_url) / "v1") + + headers = {"Content-Type": 
"application/json", "X-API-Key": api_key} + + test_task_id = "xxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" + url = f"{base_url or 'https://api.slidespeak.co/api/v1'}/task_status/{test_task_id}" + + response = requests.get(url, headers=headers) + if response.status_code != 200: + raise ToolProviderCredentialValidationError("Invalid SlidePeak API key") diff --git a/api/core/tools/provider/builtin/slidespeak/slidespeak.yaml b/api/core/tools/provider/builtin/slidespeak/slidespeak.yaml new file mode 100644 index 0000000000..9f6927f1bd --- /dev/null +++ b/api/core/tools/provider/builtin/slidespeak/slidespeak.yaml @@ -0,0 +1,22 @@ +identity: + author: Kalo Chin + name: slidespeak + label: + en_US: SlideSpeak + zh_Hans: SlideSpeak + description: + en_US: Generate presentation slides using SlideSpeak API + zh_Hans: 使用 SlideSpeak API 生成演示幻灯片 + icon: icon.png + +credentials_for_provider: + slidespeak_api_key: + type: secret-input + required: true + label: + en_US: API Key + zh_Hans: API 密钥 + placeholder: + en_US: Enter your SlideSpeak API key + zh_Hans: 输入您的 SlideSpeak API 密钥 + url: https://app.slidespeak.co/settings/developer diff --git a/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py b/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py new file mode 100644 index 0000000000..aa4ee63e97 --- /dev/null +++ b/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py @@ -0,0 +1,163 @@ +import asyncio +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Any, Optional, Union + +import aiohttp +from pydantic import ConfigDict + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.tool.builtin_tool import BuiltinTool + + +class SlidesGeneratorTool(BuiltinTool): + """ + Tool for generating presentations using the SlideSpeak API. 
+ """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + headers: Optional[dict[str, str]] = None + base_url: Optional[str] = None + timeout: Optional[aiohttp.ClientTimeout] = None + poll_interval: Optional[int] = None + + class TaskState(Enum): + FAILURE = "FAILURE" + REVOKED = "REVOKED" + SUCCESS = "SUCCESS" + PENDING = "PENDING" + RECEIVED = "RECEIVED" + STARTED = "STARTED" + + @dataclass + class PresentationRequest: + plain_text: str + length: Optional[int] = None + theme: Optional[str] = None + + async def _generate_presentation( + self, + session: aiohttp.ClientSession, + request: PresentationRequest, + ) -> dict[str, Any]: + """Generate a new presentation asynchronously""" + async with session.post( + f"{self.base_url}/presentation/generate", + headers=self.headers, + json=asdict(request), + timeout=self.timeout, + ) as response: + response.raise_for_status() + return await response.json() + + async def _get_task_status( + self, + session: aiohttp.ClientSession, + task_id: str, + ) -> dict[str, Any]: + """Get the status of a task asynchronously""" + async with session.get( + f"{self.base_url}/task_status/{task_id}", + headers=self.headers, + timeout=self.timeout, + ) as response: + response.raise_for_status() + return await response.json() + + async def _wait_for_completion( + self, + session: aiohttp.ClientSession, + task_id: str, + ) -> str: + """Wait for task completion and return download URL""" + while True: + status = await self._get_task_status(session, task_id) + task_status = self.TaskState(status["task_status"]) + if task_status == self.TaskState.SUCCESS: + return status["task_result"]["url"] + if task_status in [self.TaskState.FAILURE, self.TaskState.REVOKED]: + raise Exception(f"Task failed with status: {task_status.value}") + await asyncio.sleep(self.poll_interval) + + async def _generate_slides( + self, + plain_text: str, + length: Optional[int], + theme: Optional[str], + ) -> str: + """Generate slides and return the download URL""" + async with aiohttp.ClientSession() as session: + request = self.PresentationRequest( + plain_text=plain_text, + length=length, + theme=theme, + ) + result = await self._generate_presentation(session, request) + task_id = result["task_id"] + download_url = await self._wait_for_completion(session, task_id) + return download_url + + async def _fetch_presentation( + self, + session: aiohttp.ClientSession, + download_url: str, + ) -> bytes: + """Fetch the presentation file from the download URL""" + async with session.get(download_url, timeout=self.timeout) as response: + response.raise_for_status() + return await response.read() + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """Synchronous invoke method that runs asynchronous code""" + + async def async_invoke(): + # Extract parameters + plain_text = tool_parameters.get("plain_text", "") + length = tool_parameters.get("length") + theme = tool_parameters.get("theme") + + # Ensure runtime and credentials + if not self.runtime or not self.runtime.credentials: + raise ToolProviderCredentialValidationError("Tool runtime or credentials are missing") + + # Get API key from credentials + api_key = self.runtime.credentials.get("slidespeak_api_key") + if not api_key: + raise ToolProviderCredentialValidationError("SlideSpeak API key is missing") + + # Set configuration + self.headers = { + "Content-Type": "application/json", + "X-API-Key": api_key, + } + self.base_url = 
"https://api.slidespeak.co/api/v1" + self.timeout = aiohttp.ClientTimeout(total=30) + self.poll_interval = 2 + + # Run the asynchronous slide generation + try: + download_url = await self._generate_slides(plain_text, length, theme) + + # Fetch the presentation file + async with aiohttp.ClientSession() as session: + presentation_bytes = await self._fetch_presentation(session, download_url) + + return [ + self.create_text_message(download_url), + self.create_blob_message( + blob=presentation_bytes, + meta={"mime_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation"}, + ), + ] + except Exception as e: + return [self.create_text_message(f"An error occurred: {str(e)}")] + + # Run the asynchronous code synchronously + result = asyncio.run(async_invoke()) + return result diff --git a/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.yaml b/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.yaml new file mode 100644 index 0000000000..f881dadb20 --- /dev/null +++ b/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.yaml @@ -0,0 +1,102 @@ +identity: + name: slide_generator + author: Kalo Chin + label: + en_US: Slides Generator + zh_Hans: 幻灯片生成器 +description: + human: + en_US: Generate presentation slides from text using SlideSpeak API. + zh_Hans: 使用 SlideSpeak API 从文本生成演示幻灯片。 + llm: This tool converts text input into a presentation using the SlideSpeak API service, with options for slide length and theme. +parameters: + - name: plain_text + type: string + required: true + label: + en_US: Topic or Content + zh_Hans: 主题或内容 + human_description: + en_US: The topic or content to be converted into presentation slides. + zh_Hans: 需要转换为幻灯片的内容或主题。 + llm_description: A string containing the topic or content to be transformed into presentation slides. + form: llm + - name: length + type: number + required: false + label: + en_US: Number of Slides + zh_Hans: 幻灯片数量 + human_description: + en_US: The desired number of slides in the presentation (optional). + zh_Hans: 演示文稿中所需的幻灯片数量(可选)。 + llm_description: Optional parameter specifying the number of slides to generate. + form: form + - name: theme + type: select + required: false + label: + en_US: Presentation Theme + zh_Hans: 演示主题 + human_description: + en_US: The visual theme for the presentation (optional). + zh_Hans: 演示文稿的视觉主题(可选)。 + llm_description: Optional parameter specifying the presentation theme. 
+ options: + - label: + en_US: Adam + zh_Hans: Adam + value: adam + - label: + en_US: Aurora + zh_Hans: Aurora + value: aurora + - label: + en_US: Bruno + zh_Hans: Bruno + value: bruno + - label: + en_US: Clyde + zh_Hans: Clyde + value: clyde + - label: + en_US: Daniel + zh_Hans: Daniel + value: daniel + - label: + en_US: Default + zh_Hans: Default + value: default + - label: + en_US: Eddy + zh_Hans: Eddy + value: eddy + - label: + en_US: Felix + zh_Hans: Felix + value: felix + - label: + en_US: Gradient + zh_Hans: Gradient + value: gradient + - label: + en_US: Iris + zh_Hans: Iris + value: iris + - label: + en_US: Lavender + zh_Hans: Lavender + value: lavender + - label: + en_US: Monolith + zh_Hans: Monolith + value: monolith + - label: + en_US: Nebula + zh_Hans: Nebula + value: nebula + - label: + en_US: Nexus + zh_Hans: Nexus + value: nexus + form: form diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index cc38739c16..6464bb6602 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Union from pytz import timezone as pytz_timezone @@ -20,7 +20,7 @@ class CurrentTimeTool(BuiltinTool): tz = tool_parameters.get("timezone", "UTC") fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" if tz == "UTC": - return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + return self.create_text_message(f"{datetime.now(UTC).strftime(fm)}") try: tz = pytz_timezone(tz) diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index c779d704c3..0b4c5bd2c6 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -5,6 +5,7 @@ from urllib.parse import urlencode import httpx +from core.file.file_manager import download from core.helper import ssrf_proxy from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType @@ -138,6 +139,7 @@ class ApiTool(Tool): path_params = {} body = {} cookies = {} + files = [] # check parameters for parameter in self.api_bundle.openapi.get("parameters", []): @@ -166,8 +168,12 @@ class ApiTool(Tool): properties = body_schema.get("properties", {}) for name, property in properties.items(): if name in parameters: - # convert type - body[name] = self._convert_body_property_type(property, parameters[name]) + if property.get("format") == "binary": + f = parameters[name] + files.append((name, (f.filename, download(f), f.mime_type))) + else: + # convert type + body[name] = self._convert_body_property_type(property, parameters[name]) elif name in required: raise ToolParameterValidationError( f"Missing required parameter {name} in operation {self.api_bundle.operation_id}" @@ -182,7 +188,7 @@ class ApiTool(Tool): for name, value in path_params.items(): url = url.replace(f"{{{name}}}", f"{value}") - # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored + # parse http body data if needed if "Content-Type" in headers: if headers["Content-Type"] == "application/json": body = json.dumps(body) @@ -198,6 +204,7 @@ class ApiTool(Tool): headers=headers, cookies=cookies, data=body, + files=files, timeout=API_TOOL_DEFAULT_TIMEOUT, follow_redirects=True, ) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 
6cb6e18b6d..38a05ccf91 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from copy import deepcopy -from enum import Enum +from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator @@ -62,7 +62,7 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VariableKey(str, Enum): + class VariableKey(StrEnum): IMAGE = "image" DOCUMENT = "document" VIDEO = "video" @@ -261,7 +261,7 @@ class Tool(BaseModel, ABC): """ parameters = self.parameters or [] parameters = parameters.copy() - user_parameters = self.get_runtime_parameters() or [] + user_parameters = self.get_runtime_parameters() user_parameters = user_parameters.copy() # override parameters diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 9e290c3651..f92b43608e 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping from copy import deepcopy -from datetime import datetime, timezone +from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Optional, Union @@ -55,13 +55,18 @@ class ToolEngine: # check if this tool has only one parameter parameters = [ parameter - for parameter in tool.get_runtime_parameters() or [] + for parameter in tool.get_runtime_parameters() if parameter.form == ToolParameter.ToolParameterForm.LLM ] if parameters and len(parameters) == 1: tool_parameters = {parameters[0].name: tool_parameters} else: - raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") + try: + tool_parameters = json.loads(tool_parameters) + except Exception: + pass + if not isinstance(tool_parameters, dict): + raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: @@ -158,7 +163,7 @@ class ToolEngine: """ Invoke the tool with the given arguments. 
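The new string-handling branch in `ToolEngine` is easier to see in isolation: a bare string is either wrapped as the value of the single LLM-form parameter, or parsed as JSON and validated. A standalone restatement under that reading (the function and argument names here are illustrative, not part of the Dify API):

```python
import json

def normalize_tool_parameters(raw: str | dict, llm_parameter_names: list[str]) -> dict:
    """Illustrative restatement of the branch above, not Dify's actual API."""
    if isinstance(raw, dict):
        return raw
    if len(llm_parameter_names) == 1:
        # Exactly one LLM-form parameter: treat the whole string as its value.
        return {llm_parameter_names[0]: raw}
    try:
        raw = json.loads(raw)  # Otherwise, try to parse the string as a JSON object.
    except Exception:
        pass
    if not isinstance(raw, dict):
        raise ValueError(f"tool_parameters should be a dict, but got a string: {raw}")
    return raw
```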
""" - started_at = datetime.now(timezone.utc) + started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, error=None, @@ -176,7 +181,7 @@ class ToolEngine: meta.error = str(e) raise ToolEngineInvokeError(meta) finally: - ended_at = datetime.now(timezone.utc) + ended_at = datetime.now(UTC) meta.time_cost = (ended_at - started_at).total_seconds() return meta, response diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 83600d21c1..8b5e27f538 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -127,7 +127,7 @@ class ToolParameterConfigurationManager(BaseModel): # get tool parameters tool_parameters = self.tool_runtime.parameters or [] # get tool runtime parameters - runtime_parameters = self.tool_runtime.get_runtime_parameters() or [] + runtime_parameters = self.tool_runtime.get_runtime_parameters() # override parameters current_parameters = tool_parameters.copy() for runtime_parameter in runtime_parameters: diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 5867a11bb3..ae44b1b99d 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -161,6 +161,9 @@ class ApiBasedToolSchemaParser: def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType: parameter = parameter or {} typ = None + if parameter.get("format") == "binary": + return ToolParameter.ToolParameterType.FILE + if "type" in parameter: typ = parameter["type"] elif "schema" in parameter and "type" in parameter["schema"]: diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index b71882b043..69bd5567a4 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -118,11 +118,11 @@ class FileSegment(Segment): @property def log(self) -> str: - return str(self.value) + return "" @property def text(self) -> str: - return str(self.value) + return "" class ArrayAnySegment(ArraySegment): @@ -155,3 +155,11 @@ class ArrayFileSegment(ArraySegment): for item in self.value: items.append(item.markdown) return "\n".join(items) + + @property + def log(self) -> str: + return "" + + @property + def text(self) -> str: + return "" diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 53c2e8a3aa..af6a2a2937 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class SegmentType(str, Enum): +class SegmentType(StrEnum): NONE = "none" NUMBER = "number" STRING = "string" diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index a747266661..e174d3baa0 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel @@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from models.workflow import WorkflowNodeExecutionStatus -class NodeRunMetadataKey(str, Enum): +class NodeRunMetadataKey(StrEnum): """ Node Run Metadata Key. 
""" @@ -36,7 +36,7 @@ class NodeRunResult(BaseModel): inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[dict[str, Any]] = None # process data - outputs: Optional[dict[str, Any]] = None # node outputs + outputs: Optional[Mapping[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 213ed57f57..9642efa1a5 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class SystemVariableKey(str, Enum): +class SystemVariableKey(StrEnum): """ System Variables. """ diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index bb24b51112..baeec9bf01 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from typing import Optional @@ -63,7 +63,7 @@ class RouteNodeState(BaseModel): raise Exception(f"Invalid route status {run_result.status}") self.node_run_result = run_result - self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + self.finished_at = datetime.now(UTC).replace(tzinfo=None) class RuntimeRouteState(BaseModel): @@ -81,7 +81,7 @@ class RuntimeRouteState(BaseModel): :param node_id: node id """ - state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) self.node_state_mapping[state.id] = state return state diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 208144655b..9e9e52910e 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class NodeType(str, Enum): +class NodeType(StrEnum): START = "start" END = "end" ANSWER = "answer" diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/core/workflow/nodes/http_request/entities.py index 36ded104c1..5e39ef79d1 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/core/workflow/nodes/http_request/entities.py @@ -1,4 +1,6 @@ +import mimetypes from collections.abc import Sequence +from email.message import Message from typing import Any, Literal, Optional import httpx @@ -7,14 +9,6 @@ from pydantic import BaseModel, Field, ValidationInfo, field_validator from configs import dify_config from core.workflow.nodes.base import BaseNodeData -NON_FILE_CONTENT_TYPES = ( - "application/json", - "application/xml", - "text/html", - "text/plain", - "application/x-www-form-urlencoded", -) - class HttpRequestNodeAuthorizationConfig(BaseModel): type: Literal["basic", "bearer", "custom"] @@ -93,13 +87,53 @@ class Response: @property def is_file(self): - content_type = self.content_type + """ + Determine if the response contains a file by checking: + 1. Content-Disposition header (RFC 6266) + 2. Content characteristics + 3. 
MIME type analysis + """ + content_type = self.content_type.split(";")[0].strip().lower() content_disposition = self.response.headers.get("content-disposition", "") - return "attachment" in content_disposition or ( - not any(non_file in content_type for non_file in NON_FILE_CONTENT_TYPES) - and any(file_type in content_type for file_type in ("application/", "image/", "audio/", "video/")) - ) + # Check if it's explicitly marked as an attachment + if content_disposition: + msg = Message() + msg["content-disposition"] = content_disposition + disp_type = msg.get_content_disposition() # Returns 'attachment', 'inline', or None + filename = msg.get_filename() # Returns filename if present, None otherwise + if disp_type == "attachment" or filename is not None: + return True + + # For application types, try to detect if it's a text-based format + if content_type.startswith("application/"): + # Common text-based application types + if any( + text_type in content_type + for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql") + ): + return False + + # Try to detect if content is text-based by sampling first few bytes + try: + # Sample first 1024 bytes for text detection + content_sample = self.response.content[:1024] + content_sample.decode("utf-8") + # If we can decode as UTF-8 and find common text patterns, likely not a file + text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") + if any(marker in content_sample for marker in text_markers): + return False + except UnicodeDecodeError: + # If we can't decode as UTF-8, likely a binary file + return True + + # For other types, use MIME type analysis + main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) + if main_type: + return main_type.split("/")[0] in ("application", "image", "audio", "video") + + # For unknown types, check if it's a media type + return any(media_type in content_type for media_type in ("image/", "audio/", "video/")) @property def content_type(self) -> str: diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 80b322b068..22ad2a39f6 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -108,7 +108,7 @@ class Executor: self.content = self.variable_pool.convert_template(data[0].value).text case "json": json_string = self.variable_pool.convert_template(data[0].value).text - json_object = json.loads(json_string) + json_object = json.loads(json_string, strict=False) self.json = json_object # self.json = self._parse_object_contains_variables(json_object) case "binary": diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index ebcb6f82fb..7a489dd725 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import Field @@ -6,7 +6,7 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData -class ErrorHandleMode(str, Enum): +class ErrorHandleMode(StrEnum): TERMINATED = "terminated" CONTINUE_ON_ERROR = "continue-on-error" REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index d5428f0286..22f242a42f 100644 --- 
a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -2,7 +2,7 @@ import logging import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait -from datetime import datetime, timezone +from datetime import UTC, datetime from queue import Empty, Queue from typing import TYPE_CHECKING, Any, Optional, cast @@ -135,7 +135,7 @@ class IterationNode(BaseNode[IterationNodeData]): thread_pool_id=self.thread_pool_id, ) - start_at = datetime.now(timezone.utc).replace(tzinfo=None) + start_at = datetime.now(UTC).replace(tzinfo=None) yield IterationRunStartedEvent( iteration_id=self.id, @@ -367,7 +367,7 @@ class IterationNode(BaseNode[IterationNodeData]): """ run single iteration """ - iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None) + iter_start_at = datetime.now(UTC).replace(tzinfo=None) try: rst = graph_engine.run() @@ -440,7 +440,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -461,7 +461,7 @@ class IterationNode(BaseNode[IterationNodeData]): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -503,7 +503,7 @@ class IterationNode(BaseNode[IterationNodeData]): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index a25d563fe0..19a66087f7 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -39,7 +39,14 @@ class VisionConfig(BaseModel): class PromptConfig(BaseModel): - jinja2_variables: Optional[list[VariableSelector]] = None + jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) + + @field_validator("jinja2_variables", mode="before") + @classmethod + def convert_none_jinja2_variables(cls, v: Any): + if v is None: + return [] + return v class LLMNodeChatModelMessage(ChatModelMessage): @@ -53,7 +60,14 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: Optional[PromptConfig] = None + prompt_config: PromptConfig = Field(default_factory=PromptConfig) memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) + + @field_validator("prompt_config", 
mode="before") + @classmethod + def convert_none_prompt_config(cls, v: Any): + if v is None: + return PromptConfig() + return v diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py index f858be2515..6599221691 100644 --- a/api/core/workflow/nodes/llm/exc.py +++ b/api/core/workflow/nodes/llm/exc.py @@ -24,3 +24,17 @@ class LLMModeRequiredError(LLMNodeError): class NoPromptFoundError(LLMNodeError): """Raised when no prompt is found in the LLM configuration.""" + + +class TemplateTypeNotSupportError(LLMNodeError): + def __init__(self, *, type_name: str): + super().__init__(f"Prompt type {type_name} is not supported.") + + +class MemoryRolePrefixRequiredError(LLMNodeError): + """Raised when memory role prefix is required for completion model.""" + + +class FileTypeNotSupportError(LLMNodeError): + def __init__(self, *, type_name: str): + super().__init__(f"{type_name} type is not supported by this model") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index eb4d1c9d87..39480e34b3 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -1,4 +1,5 @@ import json +import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, cast @@ -6,21 +7,27 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.model_entities import ModelStatus from core.entities.provider_entities import QuotaUnit from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from core.file import FileType, file_manager +from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities import ( - AudioPromptMessageContent, ImagePromptMessageContent, PromptMessage, PromptMessageContentType, TextPromptMessageContent, - VideoPromptMessageContent, ) from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContent, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.variables import ( @@ -34,6 +41,8 @@ from core.variables import ( ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode @@ -62,14 +71,18 @@ from .exc import ( InvalidVariableTypeError, LLMModeRequiredError, LLMNodeError, + MemoryRolePrefixRequiredError, ModelNotExistError, 
NoPromptFoundError, + TemplateTypeNotSupportError, VariableNotFoundError, ) if TYPE_CHECKING: from core.file.models import File +logger = logging.getLogger(__name__) + class LLMNode(BaseNode[LLMNodeData]): _node_data_cls = LLMNodeData @@ -121,19 +134,19 @@ class LLMNode(BaseNode[LLMNodeData]): # fetch memory memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) - # fetch prompt messages + query = None if self.node_data.memory: - query = self.graph_runtime_state.variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - if not query: - raise VariableNotFoundError("Query not found") - query = query.text - else: - query = None + query = self.node_data.memory.query_prompt_template + if not query and ( + query_variable := self.graph_runtime_state.variable_pool.get( + (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY) + ) + ): + query = query_variable.text prompt_messages, stop = self._fetch_prompt_messages( - system_query=query, - inputs=inputs, - files=files, + user_query=query, + user_files=files, context=context, memory=memory, model_config=model_config, @@ -141,6 +154,8 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) process_data = { @@ -181,6 +196,17 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) return + except Exception as e: + logger.exception(f"Node {self.node_id} failed to run") + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=str(e), + inputs=node_inputs, + process_data=process_data, + ) + ) + return outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} @@ -203,8 +229,8 @@ class LLMNode(BaseNode[LLMNodeData]): self, node_data_model: ModelConfig, model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - stop: Optional[list[str]] = None, + prompt_messages: Sequence[PromptMessage], + stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: db.session.close() @@ -519,9 +545,8 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_prompt_messages( self, *, - system_query: str | None = None, - inputs: dict[str, str] | None = None, - files: Sequence["File"], + user_query: str | None = None, + user_files: Sequence["File"], context: str | None = None, memory: TokenBufferMemory | None = None, model_config: ModelConfigWithCredentialsEntity, @@ -529,58 +554,144 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config: MemoryConfig | None = None, vision_enabled: bool = False, vision_detail: ImagePromptMessageContent.DETAIL, - ) -> tuple[list[PromptMessage], Optional[list[str]]]: - inputs = inputs or {} + variable_pool: VariablePool, + jinja2_variables: Sequence[VariableSelector], + ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: + prompt_messages = [] - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs=inputs, - query=system_query or "", - files=files, - context=context, - memory_config=memory_config, - memory=memory, - model_config=model_config, - ) - stop = model_config.stop + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + _handle_list_messages( + messages=prompt_template, + context=context, + 
jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if user_query: + message = LLMNodeChatModelMessage( + text=user_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + _handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + ) + ) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + ) + + # Get memory text for completion model + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_config=model_config, + ) + # Insert histories into the prompt + prompt_content = prompt_messages[0].content + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + + # Add current query to the prompt message + if user_query: + prompt_content = prompt_messages[0].content.replace("#sys.query#", user_query) + prompt_messages[0].content = prompt_content + else: + raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) + + if vision_enabled and user_files: + file_prompts = [] + for file in user_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=prompt_messages[-1].content + file_prompts) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Filter prompt messages filtered_prompt_messages = [] for prompt_message in prompt_messages: - if prompt_message.is_empty(): - continue - - if not isinstance(prompt_message.content, str): + if isinstance(prompt_message.content, list): prompt_message_content = [] - for content_item in prompt_message.content or []: - # Skip image if vision is disabled - if not vision_enabled and content_item.type == PromptMessageContentType.IMAGE: + for content_item in prompt_message.content: + # Skip content if features are not defined + if not model_config.model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) continue - if isinstance(content_item, ImagePromptMessageContent): - # Override vision config if LLM node has vision config, - # cuz vision detail is related to the configuration from FileUpload feature. 
- content_item.detail = vision_detail - prompt_message_content.append(content_item) - elif isinstance( - content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_config.model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_config.model_schema.features + ) ): - prompt_message_content.append(content_item) - - if len(prompt_message_content) > 1: - prompt_message.content = prompt_message_content - elif ( - len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT - ): + continue + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data - + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue filtered_prompt_messages.append(prompt_message) - if not filtered_prompt_messages: + if len(filtered_prompt_messages) == 0: raise NoPromptFoundError( "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." ) + stop = model_config.stop return filtered_prompt_messages, stop @classmethod @@ -715,3 +826,204 @@ class LLMNode(BaseNode[LLMNodeData]): } }, } + + +def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=contents) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=contents) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=contents) + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinjia2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +): + if not template: + return "" + + jinjia2_inputs = {} + for jinja2_variable in jinjia2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + code_execute_resp = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=jinjia2_inputs, + ) + result_text = code_execute_resp["result"] + return result_text + + +def _handle_list_messages( + *, + messages: Sequence[LLMNodeChatModelMessage], + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, +) -> Sequence[PromptMessage]: + prompt_messages = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic 
message + if context: + template = message.text.replace("{#context#}", context) + else: + template = message.text + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + if isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: + # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages + + +def _calculate_rest_token( + *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity +) -> int: + rest_tokens = 2000 + + model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + model_instance = ModelInstance( + provider_model_bundle=model_config.provider_model_bundle, model=model_config.model + ) + + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in model_config.model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + model_config.parameters.get(parameter_rule.name) + or model_config.parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> Sequence[PromptMessage]: + memory_messages = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: TokenBufferMemory | None, + memory_config: MemoryConfig | None, + model_config: ModelConfigWithCredentialsEntity, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + memory_text = memory.get_history_prompt_text( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if 
memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + return memory_text + + +def _handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: Optional[str], + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, +) -> Sequence[PromptMessage]: + """Handle completion template processing outside of LLMNode class. + + Args: + template: The completion model prompt template + context: Optional context string + jinja2_variables: Variables for jinja2 template rendering + variable_pool: Variable pool for template conversion + + Returns: + Sequence of prompt messages + """ + prompt_messages = [] + if template.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=template.jinja2_text or "", + jinjia2_variables=jinja2_variables, + variable_pool=variable_pool, + ) + else: + if context: + template_text = template.text.replace("{#context#}", context) + else: + template_text = template.text + result_text = variable_pool.convert_template(template_text).text + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER + ) + prompt_messages.append(prompt_message) + return prompt_messages diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 744dfd3d8d..e855ab2d2b 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -86,12 +86,14 @@ class QuestionClassifierNode(LLMNode): ) prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - system_query=query, + user_query=query, memory=memory, model_config=model_config, - files=files, + user_files=files, vision_enabled=node_data.vision.enabled, vision_detail=node_data.vision.configs.detail, + variable_pool=variable_pool, + jinja2_variables=[], ) # handle invoke result diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 5560f26456..951e5330a3 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -250,9 +250,8 @@ class ToolNode(BaseNode[ToolNodeData]): f"{message.message}" if message.type == ToolInvokeMessage.MessageType.TEXT else f"Link: {message.message}" - if message.type == ToolInvokeMessage.MessageType.LINK - else "" for message in tool_response + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK} ] ) diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/node_data.py index 70ae29d45f..474ecefe76 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/node_data.py @@ -1,11 +1,11 @@ from collections.abc import Sequence -from enum import Enum +from enum import StrEnum from typing import Optional from core.workflow.nodes.base import BaseNodeData -class WriteMode(str, Enum): +class WriteMode(StrEnum): OVER_WRITE = "over-write" APPEND = "append" CLEAR = "clear" diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 84b251223f..6f7b143ad6 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,10 +5,9 @@ from collections.abc import Generator, 
Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.app_config.entities import FileUploadConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.models import File, FileTransferMethod, ImageConfig +from core.file.models import File from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError @@ -18,9 +17,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes import NodeType -from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent -from core.workflow.nodes.llm import LLMNodeData from core.workflow.nodes.node_mapping import node_type_classes_mapping from factories import file_factory from models.enums import UserFrom @@ -115,7 +113,12 @@ class WorkflowEntry: @classmethod def single_step_run( - cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict + cls, + *, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -135,13 +138,9 @@ class WorkflowEntry: raise ValueError("nodes not found in workflow graph") # fetch node config from node id - node_config = None - for node in nodes: - if node.get("id") == node_id: - node_config = node - break - - if not node_config: + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: raise ValueError("node id not found in workflow graph") # Get node class @@ -153,11 +152,7 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - ) + variable_pool = VariablePool(environment_variables=workflow.environment_variables) # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -183,28 +178,24 @@ class WorkflowEntry: try: # variable selector to variable mapping - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=node_config - ) - except NotImplementedError: - variable_mapping = {} - - cls.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_type=node_type, - node_data=node_instance.node_data, + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=node_config ) + except NotImplementedError: + variable_mapping = {} + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + ) + try: # run node generator = node_instance.run() - - return node_instance, generator except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + return node_instance, generator @staticmethod def handle_special_values(value: Optional[Mapping[str, 
Any]]) -> Mapping[str, Any] | None: @@ -231,12 +222,11 @@ class WorkflowEntry: @classmethod def mapping_user_inputs_to_variable_pool( cls, + *, variable_mapping: Mapping[str, Sequence[str]], user_inputs: dict, variable_pool: VariablePool, tenant_id: str, - node_type: NodeType, - node_data: BaseNodeData, ) -> None: for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable @@ -254,40 +244,21 @@ class WorkflowEntry: # fetch variable node id from variable selector variable_node_id = variable_selector[0] variable_key_list = variable_selector[1:] - variable_key_list = cast(list[str], variable_key_list) + variable_key_list = list(variable_key_list) # get input value input_value = user_inputs.get(node_variable) if not input_value: input_value = user_inputs.get(node_variable_key) - # FIXME: temp fix for image type - if node_type == NodeType.LLM: - new_value = [] - if isinstance(input_value, list): - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in input_value: - if isinstance(item, dict) and "type" in item and item["type"] == "image": - transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - mapping = { - "id": item.get("id"), - "transfer_method": transfer_method, - "upload_file_id": item.get("upload_file_id"), - "url": item.get("url"), - } - config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - new_value.append(file) - - if new_value: - input_value = new_value + if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: + input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + if ( + isinstance(input_value, list) + and all(isinstance(item, dict) for item in input_value) + and all("type" in item and "transfer_method" in item for item in input_value) + ): + input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) # append variable and value to variable pool variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 5af45e1e50..24fa013697 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -33,7 +33,7 @@ def handle(sender, **kwargs): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py index a80572c0de..f225ef8e88 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created @@ -17,5 
+17,5 @@ def handle(sender, **kwargs): db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == application_generate_entity.model_conf.provider, - ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) + ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) db.session.commit() diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index e1f8409f21..36f06c1104 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,11 +1,12 @@ import redis +from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel from configs import dify_config -class RedisClientWrapper(redis.Redis): +class RedisClientWrapper: """ A wrapper class for the Redis client that addresses the issue where the global `redis_client` variable cannot be updated when a new Redis instance is returned @@ -71,6 +72,12 @@ def init_app(app): ) master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) redis_client.initialize(master) + elif dify_config.REDIS_USE_CLUSTERS: + nodes = [ + ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) + for node in dify_config.REDIS_CLUSTERS.split(",") + ] + redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) else: redis_params.update( { diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 11a7544274..b26caa8671 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas @@ -67,7 +67,7 @@ class AzureBlobStorage(BaseStorage): account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), + expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index 415bf251f6..e7fa405afa 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class StorageType(str, Enum): +class StorageType(StrEnum): ALIYUN_OSS = "aliyun-oss" AZURE_BLOB = "azure-blob" BAIDU_OBS = "baidu-obs" diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 94bbeebd6d..ad8dba8190 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,10 +1,11 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence -from typing import Any +from typing import Any, cast import httpx from sqlalchemy import select +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.helper import ssrf_proxy from extensions.ext_database
import db @@ -71,7 +72,12 @@ def build_from_mapping( transfer_method=transfer_method, ) - if not _is_file_valid_with_config(file=file, config=config): + if not _is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension, + file_transfer_method=file.transfer_method, + config=config, + ): raise ValueError(f"File validation failed for file: {file.filename}") return file @@ -80,12 +86,9 @@ def build_from_mapping( def build_from_mappings( *, mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None, + config: FileUploadConfig | None = None, tenant_id: str, ) -> Sequence[File]: - if not config: - return [] - files = [ build_from_mapping( mapping=mapping, @@ -96,13 +99,14 @@ def build_from_mappings( ] if ( + config # If image config is set. - config.image_config + and config.image_config # And the number of image files exceeds the maximum limit and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits ): raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config.number_limits and len(files) > config.number_limits: + if config and config.number_limits and len(files) > config.number_limits: raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") return files @@ -114,17 +118,18 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: - file_type = FileType.value_of(mapping.get("type")) stmt = select(UploadFile).where( UploadFile.id == mapping.get("upload_file_id"), UploadFile.tenant_id == tenant_id, ) row = db.session.scalar(stmt) - if row is None: raise ValueError("Invalid upload file") + file_type = FileType(mapping.get("type")) + file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type) + return File( id=mapping.get("id"), filename=row.name, @@ -152,11 +157,14 @@ def _build_from_remote_url( mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." 
in filename else ".bin" + file_type = FileType(mapping.get("type")) + file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type) + return File( id=mapping.get("id"), filename=filename, tenant_id=tenant_id, - type=FileType.value_of(mapping.get("type")), + type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -166,11 +174,12 @@ def _build_from_remote_url( def _get_remote_file_info(url: str): - mime_type = mimetypes.guess_type(url)[0] or "" file_size = -1 filename = url.split("/")[-1].split("?")[0] or "unknown_file" + mime_type = mimetypes.guess_type(filename)[0] or "" resp = ssrf_proxy.head(url, follow_redirects=True) + resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) @@ -180,20 +189,6 @@ def _get_remote_file_info(url: str): return mime_type, filename, file_size -def _get_file_type_by_mimetype(mime_type: str) -> FileType: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - def _build_from_tool_file( *, mapping: Mapping[str, Any], @@ -213,7 +208,8 @@ def _build_from_tool_file( raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype)) + file_type = FileType(mapping.get("type")) + file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype) return File( id=mapping.get("id"), @@ -229,18 +225,72 @@ def _build_from_tool_file( ) -def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: - if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: +def _is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): return False - if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions: + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): return False - if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods: + if config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: return False - if file.type == FileType.IMAGE and config.image_config: - if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + if input_file_type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods: return False return True + + +def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType: + """ + If custom type, try to guess the file type by extension and mime_type. 
+ """ + if file_type != FileType.CUSTOM: + return FileType(file_type) + guessed_type = None + if extension: + guessed_type = _get_file_type_by_extension(extension) + if guessed_type is None and mime_type: + guessed_type = _get_file_type_by_mimetype(mime_type) + return guessed_type or FileType.CUSTOM + + +def _get_file_type_by_extension(extension: str) -> FileType | None: + extension = extension.lstrip(".") + if extension in IMAGE_EXTENSIONS: + return FileType.IMAGE + elif extension in VIDEO_EXTENSIONS: + return FileType.VIDEO + elif extension in AUDIO_EXTENSIONS: + return FileType.AUDIO + elif extension in DOCUMENT_EXTENSIONS: + return FileType.DOCUMENT + + +def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: + if "image" in mime_type: + file_type = FileType.IMAGE + elif "video" in mime_type: + file_type = FileType.VIDEO + elif "audio" in mime_type: + file_type = FileType.AUDIO + elif "text" in mime_type or "pdf" in mime_type: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + return file_type diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index aa353a3cc1..abb27fdad1 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -190,3 +190,12 @@ app_site_fields = { "show_workflow_steps": fields.Boolean, "use_icon_as_answer_icon": fields.Boolean, } + +app_import_fields = { + "id": fields.String, + "status": fields.String, + "app_id": fields.String, + "current_dsl_version": fields.String, + "imported_dsl_version": fields.String, + "error": fields.String, +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index b32423f10c..533e3a0837 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -41,6 +41,7 @@ dataset_retrieval_model_fields = { external_retrieval_model_fields = { "top_k": fields.Integer, "score_threshold": fields.Float, + "score_threshold_enabled": fields.Boolean, } tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String} diff --git a/api/libs/helper.py b/api/libs/helper.py index 7638796508..b98a4829e8 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -31,9 +31,12 @@ class AppIconUrlField(fields.Raw): if obj is None: return None - from models.model import IconType + from models.model import App, IconType, Site - if obj.icon_type == IconType.IMAGE.value: + if isinstance(obj, dict) and "app" in obj: + obj = obj["app"] + + if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value: return file_helpers.get_signed_file_url(obj.icon) return None diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index e747ea97ad..53aa0f2d45 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -70,7 +70,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -106,7 +106,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: 
new_data_source_binding = DataSourceOauthBinding( @@ -141,7 +141,7 @@ class NotionOAuth(OAuthDataSource): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: raise ValueError("Data source binding not found") diff --git a/api/models/account.py b/api/models/account.py index 60b4f11aad..951e836dec 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,7 +8,7 @@ from extensions.ext_database import db from .types import StringUUID -class AccountStatus(str, enum.Enum): +class AccountStatus(enum.StrEnum): PENDING = "pending" UNINITIALIZED = "uninitialized" ACTIVE = "active" @@ -56,8 +56,8 @@ class Account(UserMixin, db.Model): self._current_tenant = tenant @property - def current_tenant_id(self): - return self._current_tenant.id + def current_tenant_id(self) -> str | None: + return self._current_tenant.id if self._current_tenant else None @current_tenant_id.setter def current_tenant_id(self, value: str): @@ -108,6 +108,10 @@ class Account(UserMixin, db.Model): def is_admin_or_owner(self): return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) + @property + def is_admin(self): + return TenantAccountRole.is_admin_role(self._current_tenant.current_role) + @property def is_editor(self): return TenantAccountRole.is_editing_role(self._current_tenant.current_role) @@ -121,12 +125,12 @@ class Account(UserMixin, db.Model): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR -class TenantStatus(str, enum.Enum): +class TenantStatus(enum.StrEnum): NORMAL = "normal" ARCHIVE = "archive" -class TenantAccountRole(str, enum.Enum): +class TenantAccountRole(enum.StrEnum): OWNER = "owner" ADMIN = "admin" EDITOR = "editor" @@ -147,6 +151,10 @@ class TenantAccountRole(str, enum.Enum): def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + @staticmethod + def is_admin_role(role: str) -> bool: + return role and role == TenantAccountRole.ADMIN + @staticmethod def is_non_owner_role(role: str) -> bool: return role and role in { diff --git a/api/models/dataset.py b/api/models/dataset.py index a8b2c419d1..8ab957e875 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -23,7 +23,7 @@ from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID -class DatasetPermissionEnum(str, enum.Enum): +class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" PARTIAL_TEAM = "partial_members" diff --git a/api/models/enums.py b/api/models/enums.py index a83d35e042..7b9500ebe4 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,16 +1,16 @@ -from enum import Enum +from enum import StrEnum -class CreatedByRole(str, Enum): +class CreatedByRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" -class UserFrom(str, Enum): +class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" -class WorkflowRunTriggeredFrom(str, Enum): +class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" diff --git a/api/models/model.py b/api/models/model.py index b7c89ce97c..03b8e0bea5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ import re import uuid from collections.abc import Mapping from datetime import datetime -from enum import Enum 
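The `(str, Enum)` → `StrEnum` migrations across these model files are likewise behavior-preserving for comparisons and serialization; `enum.StrEnum` is available from Python 3.11, matching the dropped 3.10 support. A minimal equivalence sketch (class names are illustrative):

```python
from enum import Enum, StrEnum


class OldStatus(str, Enum):  # pre-3.11 idiom being removed in this diff
    NORMAL = "normal"


class NewStatus(StrEnum):  # 3.11+ replacement used across the diff
    NORMAL = "normal"


# Members of both styles still compare equal to their plain-string values.
assert OldStatus.NORMAL == "normal" == NewStatus.NORMAL
# Nuance: str() of a StrEnum member is the bare value, while the
# (str, Enum) mixin keeps the "ClassName.MEMBER" style string.
assert str(NewStatus.NORMAL) == "normal"
assert str(OldStatus.NORMAL) == "OldStatus.NORMAL"
```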
+from enum import Enum, StrEnum from typing import Any, Literal, Optional import sqlalchemy as sa @@ -32,7 +32,7 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(str, Enum): +class AppMode(StrEnum): COMPLETION = "completion" WORKFLOW = "workflow" CHAT = "chat" @@ -68,7 +68,7 @@ class App(db.Model): name = db.Column(db.String(255), nullable=False) description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying")) mode = db.Column(db.String(255), nullable=False) - icon_type = db.Column(db.String(255), nullable=True) + icon_type = db.Column(db.String(255), nullable=True) # image, emoji icon = db.Column(db.String(255)) icon_background = db.Column(db.String(255)) app_model_config_id = db.Column(StringUUID, nullable=True) @@ -255,7 +255,7 @@ class AppModelConfig(db.Model): @property def model_dict(self) -> dict: - return json.loads(self.model) if self.model else None + return json.loads(self.model) if self.model else {} @property def suggested_questions_list(self) -> list: @@ -600,8 +600,8 @@ class Conversation(db.Model): app_model_config = ( db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first() ) - - model_config = app_model_config.to_dict() + if app_model_config: + model_config = app_model_config.to_dict() model_config["model_id"] = self.model_id model_config["provider"] = self.model_provider diff --git a/api/models/task.py b/api/models/task.py index 57b147c78d..5d89ff85ac 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from celery import states @@ -16,8 +16,8 @@ class CeleryTask(db.Model): result = db.Column(db.PickleType, nullable=True) date_done = db.Column( db.DateTime, - default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + default=lambda: datetime.now(UTC).replace(tzinfo=None), + onupdate=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True, ) traceback = db.Column(db.Text, nullable=True) @@ -37,4 +37,4 @@ class CeleryTaskSet(db.Model): id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True) diff --git a/api/models/workflow.py b/api/models/workflow.py index c6b3000083..fd53f137f9 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping, Sequence -from datetime import datetime, timezone -from enum import Enum +from datetime import UTC, datetime +from enum import Enum, StrEnum from typing import Any, Optional, Union import sqlalchemy as sa @@ -108,7 +108,7 @@ class Workflow(db.Model): ) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp() + sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, 
server_default="{}" @@ -314,7 +314,7 @@ class Workflow(db.Model): ) -class WorkflowRunStatus(Enum): +class WorkflowRunStatus(StrEnum): """ Workflow Run Status Enum """ @@ -393,13 +393,13 @@ class WorkflowRun(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - outputs: Mapped[str] = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped + outputs: Mapped[str] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) + created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) diff --git a/api/poetry.lock b/api/poetry.lock index 6021ae5c74..958673a00b 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -114,7 +114,6 @@ files = [ [package.dependencies] aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" @@ -483,10 +482,8 @@ files = [ ] [package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] @@ -519,9 +516,6 @@ files = [ {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, ] -[package.dependencies] -typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} - [package.extras] tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] @@ -951,6 +945,10 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, + {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = 
"Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -963,8 +961,14 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, + {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, + {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, + {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ -975,8 +979,24 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, + {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, + {file = 
"Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, + {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, + {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, + {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, + {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, + {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -986,6 +1006,10 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, + {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = 
"sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -997,6 +1021,10 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, + {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -1009,6 +1037,10 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, + {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -1021,6 +1053,10 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, 
+ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, + {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -1092,10 +1128,8 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "os_name == \"nt\""} -importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} packaging = ">=19.1" pyproject_hooks = "*" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} [package.extras] docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] @@ -1388,36 +1422,40 @@ files = [ [[package]] name = "chroma-hnswlib" -version = "0.7.3" +version = "0.7.6" description = "Chromas fork of hnswlib" optional = false python-versions = "*" files = [ - {file = "chroma-hnswlib-0.7.3.tar.gz", hash = "sha256:b6137bedde49fffda6af93b0297fe00429fc61e5a072b1ed9377f909ed95a932"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59d6a7c6f863c67aeb23e79a64001d537060b6995c3eca9a06e349ff7b0998ca"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d71a3f4f232f537b6152947006bd32bc1629a8686df22fd97777b70f416c127a"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c92dc1ebe062188e53970ba13f6b07e0ae32e64c9770eb7f7ffa83f149d4210"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49da700a6656fed8753f68d44b8cc8ae46efc99fc8a22a6d970dc1697f49b403"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:108bc4c293d819b56476d8f7865803cb03afd6ca128a2a04d678fffc139af029"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:11e7ca93fb8192214ac2b9c0943641ac0daf8f9d4591bb7b73be808a83835667"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f552e4d23edc06cdeb553cdc757d2fe190cdeb10d43093d6a3319f8d4bf1c6b"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f96f4d5699e486eb1fb95849fe35ab79ab0901265805be7e60f4eaa83ce263ec"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:368e57fe9ebae05ee5844840fa588028a023d1182b0cfdb1d13f607c9ea05756"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:b7dca27b8896b494456db0fd705b689ac6b73af78e186eb6a42fea2de4f71c6f"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:70f897dc6218afa1d99f43a9ad5eb82f392df31f57ff514ccf4eeadecd62f544"}, - {file = 
"chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aef10b4952708f5a1381c124a29aead0c356f8d7d6e0b520b778aaa62a356f4"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ee2d8d1529fca3898d512079144ec3e28a81d9c17e15e0ea4665697a7923253"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a4021a70e898783cd6f26e00008b494c6249a7babe8774e90ce4766dd288c8ba"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a8f61fa1d417fda848e3ba06c07671f14806a2585272b175ba47501b066fe6b1"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d7563be58bc98e8f0866907368e22ae218d6060601b79c42f59af4eccbbd2e0a"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51b8d411486ee70d7b66ec08cc8b9b6620116b650df9c19076d2d8b6ce2ae914"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d706782b628e4f43f1b8a81e9120ac486837fbd9bcb8ced70fe0d9b95c72d77"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:54f053dedc0e3ba657f05fec6e73dd541bc5db5b09aa8bc146466ffb734bdc86"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e607c5a71c610a73167a517062d302c0827ccdd6e259af6e4869a5c1306ffb5d"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c2358a795870156af6761890f9eb5ca8cade57eb10c5f046fe94dae1faa04b9e"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cea425df2e6b8a5e201fff0d922a1cc1d165b3cfe762b1408075723c8892218"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:454df3dd3e97aa784fba7cf888ad191e0087eef0fd8c70daf28b753b3b591170"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:df587d15007ca701c6de0ee7d5585dd5e976b7edd2b30ac72bc376b3c3f85882"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f35192fbbeadc8c0633f0a69c3d3e9f1a4eab3a46b65458bbcbcabdd9e895c36"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f007b608c96362b8f0c8b6b2ac94f67f83fcbabd857c378ae82007ec92f4d82"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:456fd88fa0d14e6b385358515aef69fc89b3c2191706fd9aee62087b62aad09c"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dfaae825499c2beaa3b75a12d7ec713b64226df72a5c4097203e3ed532680da"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:2487201982241fb1581be26524145092c95902cb09fc2646ccfbc407de3328ec"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81181d54a2b1e4727369486a631f977ffc53c5533d26e3d366dda243fb0998ca"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b4ab4e11f1083dd0a11ee4f0e0b183ca9f0f2ed63ededba1935b13ce2b3606f"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53db45cd9173d95b4b0bdccb4dbff4c54a42b51420599c32267f3abbeb795170"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c093f07a010b499c00a15bc9376036ee4800d335360570b14f7fe92badcdcf9"}, + {file = 
"chroma_hnswlib-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:0540b0ac96e47d0aa39e88ea4714358ae05d64bbe6bf33c52f316c664190a6a3"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e87e9b616c281bfbe748d01705817c71211613c3b063021f7ed5e47173556cb7"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec5ca25bc7b66d2ecbf14502b5729cde25f70945d22f2aaf523c2d747ea68912"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305ae491de9d5f3c51e8bd52d84fdf2545a4a2bc7af49765cda286b7bb30b1d4"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:822ede968d25a2c88823ca078a58f92c9b5c4142e38c7c8b4c48178894a0a3c5"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2fe6ea949047beed19a94b33f41fe882a691e58b70c55fdaa90274ae78be046f"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feceff971e2a2728c9ddd862a9dd6eb9f638377ad98438876c9aeac96c9482f5"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0633b60e00a2b92314d0bf5bbc0da3d3320be72c7e3f4a9b19f4609dc2b2ab"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a566abe32fab42291f766d667bdbfa234a7f457dcbd2ba19948b7a978c8ca624"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6be47853d9a58dedcfa90fc846af202b071f028bbafe1d8711bf64fe5a7f6111"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a7af35bdd39a88bffa49f9bb4bf4f9040b684514a024435a1ef5cdff980579d"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53b1f1551f2b5ad94eb610207bde1bb476245fc5097a2bec2b476c653c58bde"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3085402958dbdc9ff5626ae58d696948e715aef88c86d1e3f9285a88f1afd3bc"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:77326f658a15adfb806a16543f7db7c45f06fd787d699e643642d6bde8ed49c4"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:93b056ab4e25adab861dfef21e1d2a2756b18be5bc9c292aa252fa12bb44e6ae"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fe91f018b30452c16c811fd6c8ede01f84e5a9f3c23e0758775e57f1c3778871"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6c0e627476f0f4d9e153420d36042dd9c6c3671cfd1fe511c0253e38c2a1039"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9796a4536b7de6c6d76a792ba03e08f5aaa53e97e052709568e50b4d20c04f"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:d30e2db08e7ffdcc415bd072883a322de5995eb6ec28a8f8c054103bbd3ec1e0"}, + {file = "chroma_hnswlib-0.7.6.tar.gz", hash = "sha256:4dce282543039681160259d29fcde6151cc9106c6461e0485f57cdccd83059b7"}, ] [package.dependencies] @@ -1425,26 +1463,26 @@ numpy = "*" [[package]] name = "chromadb" -version = "0.5.1" +version = "0.5.20" description = "Chroma." 
optional = false python-versions = ">=3.8" files = [ - {file = "chromadb-0.5.1-py3-none-any.whl", hash = "sha256:61f1f75a672b6edce7f1c8875c67e2aaaaf130dc1c1684431fbc42ad7240d01d"}, - {file = "chromadb-0.5.1.tar.gz", hash = "sha256:e2b2b6a34c2a949bedcaa42fa7775f40c7f6667848fc8094dcbf97fc0d30bee7"}, + {file = "chromadb-0.5.20-py3-none-any.whl", hash = "sha256:9550ba1b6dce911e35cac2568b301badf4b42f457b99a432bdeec2b6b9dd3680"}, + {file = "chromadb-0.5.20.tar.gz", hash = "sha256:19513a23b2d20059866216bfd80195d1d4a160ffba234b8899f5e80978160ca7"}, ] [package.dependencies] bcrypt = ">=4.0.1" build = ">=1.0.3" -chroma-hnswlib = "0.7.3" +chroma-hnswlib = "0.7.6" fastapi = ">=0.95.2" grpcio = ">=1.58.0" httpx = ">=0.27.0" importlib-resources = "*" kubernetes = ">=28.1.0" mmh3 = ">=4.0.1" -numpy = ">=1.22.5,<2.0.0" +numpy = ">=1.22.5" onnxruntime = ">=1.14.1" opentelemetry-api = ">=1.2.0" opentelemetry-exporter-otlp-proto-grpc = ">=1.2.0" @@ -1456,7 +1494,7 @@ posthog = ">=2.4.0" pydantic = ">=1.9" pypika = ">=0.48.9" PyYAML = ">=6.0.0" -requests = ">=2.28" +rich = ">=10.11.0" tenacity = ">=8.2.3" tokenizers = ">=0.13.2" tqdm = ">=4.65.0" @@ -2036,9 +2074,6 @@ files = [ {file = "dataclass_wizard-0.28.0-py2.py3-none-any.whl", hash = "sha256:996fa46475b9192a48a057c34f04597bc97be5bc2f163b99cb1de6f778ca1f7f"}, ] -[package.dependencies] -typing-extensions = {version = ">=4", markers = "python_version == \"3.9\" or python_version == \"3.10\""} - [package.extras] dev = ["Sphinx (==7.4.7)", "Sphinx (==8.1.3)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dataclass-factory (==2.16)", "dataclass-wizard[toml]", "dataclasses-json (==0.6.7)", "flake8 (>=3)", "jsons (==1.6.3)", "pip (>=21.3.1)", "pytest (==8.3.3)", "pytest-cov (==6.0.0)", "pytest-mock (>=3.6.1)", "pytimeparse (==1.1.8)", "sphinx-issues (==5.0.0)", "tomli (>=2,<3)", "tomli (>=2,<3)", "tomli-w (>=1,<2)", "tox (==4.23.2)", "twine (==5.1.1)", "watchdog[watchmedo] (==6.0.0)", "wheel (==0.45.0)"] timedelta = ["pytimeparse (>=1.1.7)"] @@ -2410,18 +2445,19 @@ files = [ tests = ["pytest"] [[package]] -name = "exceptiongroup" -version = "1.2.2" -description = "Backport of PEP 654 (exception groups)" +name = "faker" +version = "32.1.0" +description = "Faker is a Python package that generates fake data for you." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, - {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, + {file = "Faker-32.1.0-py3-none-any.whl", hash = "sha256:c77522577863c264bdc9dad3a2a750ad3f7ee43ff8185072e482992288898814"}, + {file = "faker-32.1.0.tar.gz", hash = "sha256:aac536ba04e6b7beb2332c67df78485fc29c1880ff723beac6d1efd45e2f10f5"}, ] -[package.extras] -test = ["pytest (>=6)"] +[package.dependencies] +python-dateutil = ">=2.4" +typing-extensions = "*" [[package]] name = "fal-client" @@ -3030,59 +3066,54 @@ files = [ [[package]] name = "gevent" -version = "23.9.1" +version = "24.11.1" description = "Coroutine-based network library" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "gevent-23.9.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:a3c5e9b1f766a7a64833334a18539a362fb563f6c4682f9634dea72cbe24f771"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b101086f109168b23fa3586fccd1133494bdb97f86920a24dc0b23984dc30b69"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:36a549d632c14684bcbbd3014a6ce2666c5f2a500f34d58d32df6c9ea38b6535"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:272cffdf535978d59c38ed837916dfd2b5d193be1e9e5dcc60a5f4d5025dd98a"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dcb8612787a7f4626aa881ff15ff25439561a429f5b303048f0fca8a1c781c39"}, - {file = "gevent-23.9.1-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:d57737860bfc332b9b5aa438963986afe90f49645f6e053140cfa0fa1bdae1ae"}, - {file = "gevent-23.9.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5f3c781c84794926d853d6fb58554dc0dcc800ba25c41d42f6959c344b4db5a6"}, - {file = "gevent-23.9.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:dbb22a9bbd6a13e925815ce70b940d1578dbe5d4013f20d23e8a11eddf8d14a7"}, - {file = "gevent-23.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:707904027d7130ff3e59ea387dddceedb133cc742b00b3ffe696d567147a9c9e"}, - {file = "gevent-23.9.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:45792c45d60f6ce3d19651d7fde0bc13e01b56bb4db60d3f32ab7d9ec467374c"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e24c2af9638d6c989caffc691a039d7c7022a31c0363da367c0d32ceb4a0648"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e1ead6863e596a8cc2a03e26a7a0981f84b6b3e956101135ff6d02df4d9a6b07"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65883ac026731ac112184680d1f0f1e39fa6f4389fd1fc0bf46cc1388e2599f9"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7af500da05363e66f122896012acb6e101a552682f2352b618e541c941a011"}, - {file = "gevent-23.9.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:c3e5d2fa532e4d3450595244de8ccf51f5721a05088813c1abd93ad274fe15e7"}, - {file = "gevent-23.9.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c84d34256c243b0a53d4335ef0bc76c735873986d478c53073861a92566a8d71"}, - {file = 
"gevent-23.9.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ada07076b380918829250201df1d016bdafb3acf352f35e5693b59dceee8dd2e"}, - {file = "gevent-23.9.1-cp311-cp311-win_amd64.whl", hash = "sha256:921dda1c0b84e3d3b1778efa362d61ed29e2b215b90f81d498eb4d8eafcd0b7a"}, - {file = "gevent-23.9.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ed7a048d3e526a5c1d55c44cb3bc06cfdc1947d06d45006cc4cf60dedc628904"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c1abc6f25f475adc33e5fc2dbcc26a732608ac5375d0d306228738a9ae14d3b"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4368f341a5f51611411ec3fc62426f52ac3d6d42eaee9ed0f9eebe715c80184e"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:52b4abf28e837f1865a9bdeef58ff6afd07d1d888b70b6804557e7908032e599"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52e9f12cd1cda96603ce6b113d934f1aafb873e2c13182cf8e86d2c5c41982ea"}, - {file = "gevent-23.9.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:de350fde10efa87ea60d742901e1053eb2127ebd8b59a7d3b90597eb4e586599"}, - {file = "gevent-23.9.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:fde6402c5432b835fbb7698f1c7f2809c8d6b2bd9d047ac1f5a7c1d5aa569303"}, - {file = "gevent-23.9.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:dd6c32ab977ecf7c7b8c2611ed95fa4aaebd69b74bf08f4b4960ad516861517d"}, - {file = "gevent-23.9.1-cp312-cp312-win_amd64.whl", hash = "sha256:455e5ee8103f722b503fa45dedb04f3ffdec978c1524647f8ba72b4f08490af1"}, - {file = "gevent-23.9.1-cp38-cp38-macosx_11_0_universal2.whl", hash = "sha256:7ccf0fd378257cb77d91c116e15c99e533374a8153632c48a3ecae7f7f4f09fe"}, - {file = "gevent-23.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d163d59f1be5a4c4efcdd13c2177baaf24aadf721fdf2e1af9ee54a998d160f5"}, - {file = "gevent-23.9.1-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:7532c17bc6c1cbac265e751b95000961715adef35a25d2b0b1813aa7263fb397"}, - {file = "gevent-23.9.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:78eebaf5e73ff91d34df48f4e35581ab4c84e22dd5338ef32714264063c57507"}, - {file = "gevent-23.9.1-cp38-cp38-win32.whl", hash = "sha256:f632487c87866094546a74eefbca2c74c1d03638b715b6feb12e80120960185a"}, - {file = "gevent-23.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:62d121344f7465e3739989ad6b91f53a6ca9110518231553fe5846dbe1b4518f"}, - {file = "gevent-23.9.1-cp39-cp39-macosx_11_0_universal2.whl", hash = "sha256:bf456bd6b992eb0e1e869e2fd0caf817f0253e55ca7977fd0e72d0336a8c1c6a"}, - {file = "gevent-23.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43daf68496c03a35287b8b617f9f91e0e7c0d042aebcc060cadc3f049aadd653"}, - {file = "gevent-23.9.1-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:7c28e38dcde327c217fdafb9d5d17d3e772f636f35df15ffae2d933a5587addd"}, - {file = "gevent-23.9.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:fae8d5b5b8fa2a8f63b39f5447168b02db10c888a3e387ed7af2bd1b8612e543"}, - {file = "gevent-23.9.1-cp39-cp39-win32.whl", hash = "sha256:2c7b5c9912378e5f5ccf180d1fdb1e83f42b71823483066eddbe10ef1a2fcaa2"}, - {file = "gevent-23.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:a2898b7048771917d85a1d548fd378e8a7b2ca963db8e17c6d90c76b495e0e2b"}, - {file = "gevent-23.9.1.tar.gz", hash = "sha256:72c002235390d46f94938a96920d8856d4ffd9ddf62a303a0d7c118894097e34"}, + {file = 
"gevent-24.11.1-cp310-cp310-macosx_11_0_universal2.whl", hash = "sha256:92fe5dfee4e671c74ffaa431fd7ffd0ebb4b339363d24d0d944de532409b935e"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b7bfcfe08d038e1fa6de458891bca65c1ada6d145474274285822896a858c870"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7398c629d43b1b6fd785db8ebd46c0a353880a6fab03d1cf9b6788e7240ee32e"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d7886b63ebfb865178ab28784accd32f287d5349b3ed71094c86e4d3ca738af5"}, + {file = "gevent-24.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9ca80711e6553880974898d99357fb649e062f9058418a92120ca06c18c3c59"}, + {file = "gevent-24.11.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e24181d172f50097ac8fc272c8c5b030149b630df02d1c639ee9f878a470ba2b"}, + {file = "gevent-24.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1d4fadc319b13ef0a3c44d2792f7918cf1bca27cacd4d41431c22e6b46668026"}, + {file = "gevent-24.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:3d882faa24f347f761f934786dde6c73aa6c9187ee710189f12dcc3a63ed4a50"}, + {file = "gevent-24.11.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:351d1c0e4ef2b618ace74c91b9b28b3eaa0dd45141878a964e03c7873af09f62"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5efe72e99b7243e222ba0c2c2ce9618d7d36644c166d63373af239da1036bab"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9d3b249e4e1f40c598ab8393fc01ae6a3b4d51fc1adae56d9ba5b315f6b2d758"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81d918e952954675f93fb39001da02113ec4d5f4921bf5a0cc29719af6824e5d"}, + {file = "gevent-24.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c9c935b83d40c748b6421625465b7308d87c7b3717275acd587eef2bd1c39546"}, + {file = "gevent-24.11.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff96c5739834c9a594db0e12bf59cb3fa0e5102fc7b893972118a3166733d61c"}, + {file = "gevent-24.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d6c0a065e31ef04658f799215dddae8752d636de2bed61365c358f9c91e7af61"}, + {file = "gevent-24.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:97e2f3999a5c0656f42065d02939d64fffaf55861f7d62b0107a08f52c984897"}, + {file = "gevent-24.11.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:a3d75fa387b69c751a3d7c5c3ce7092a171555126e136c1d21ecd8b50c7a6e46"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:beede1d1cff0c6fafae3ab58a0c470d7526196ef4cd6cc18e7769f207f2ea4eb"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85329d556aaedced90a993226d7d1186a539c843100d393f2349b28c55131c85"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:816b3883fa6842c1cf9d2786722014a0fd31b6312cca1f749890b9803000bad6"}, + {file = "gevent-24.11.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b24d800328c39456534e3bc3e1684a28747729082684634789c2f5a8febe7671"}, + {file = "gevent-24.11.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a5f1701ce0f7832f333dd2faf624484cbac99e60656bfbb72504decd42970f0f"}, + {file = 
"gevent-24.11.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d740206e69dfdfdcd34510c20adcb9777ce2cc18973b3441ab9767cd8948ca8a"}, + {file = "gevent-24.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:68bee86b6e1c041a187347ef84cf03a792f0b6c7238378bf6ba4118af11feaae"}, + {file = "gevent-24.11.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:d618e118fdb7af1d6c1a96597a5cd6ac84a9f3732b5be8515c6a66e098d498b6"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2142704c2adce9cd92f6600f371afb2860a446bfd0be5bd86cca5b3e12130766"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92e0d7759de2450a501effd99374256b26359e801b2d8bf3eedd3751973e87f5"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca845138965c8c56d1550499d6b923eb1a2331acfa9e13b817ad8305dde83d11"}, + {file = "gevent-24.11.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:356b73d52a227d3313f8f828025b665deada57a43d02b1cf54e5d39028dbcf8d"}, + {file = "gevent-24.11.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:58851f23c4bdb70390f10fc020c973ffcf409eb1664086792c8b1e20f25eef43"}, + {file = "gevent-24.11.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:1ea50009ecb7f1327347c37e9eb6561bdbc7de290769ee1404107b9a9cba7cf1"}, + {file = "gevent-24.11.1-cp313-cp313-win_amd64.whl", hash = "sha256:ec68e270543ecd532c4c1d70fca020f90aa5486ad49c4f3b8b2e64a66f5c9274"}, + {file = "gevent-24.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d9347690f4e53de2c4af74e62d6fabc940b6d4a6cad555b5a379f61e7d3f2a8e"}, + {file = "gevent-24.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8619d5c888cb7aebf9aec6703e410620ef5ad48cdc2d813dd606f8aa7ace675f"}, + {file = "gevent-24.11.1-cp39-cp39-win32.whl", hash = "sha256:c6b775381f805ff5faf250e3a07c0819529571d19bb2a9d474bee8c3f90d66af"}, + {file = "gevent-24.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:1c3443b0ed23dcb7c36a748d42587168672953d368f2956b17fad36d43b58836"}, + {file = "gevent-24.11.1-pp310-pypy310_pp73-macosx_11_0_universal2.whl", hash = "sha256:f43f47e702d0c8e1b8b997c00f1601486f9f976f84ab704f8f11536e3fa144c9"}, + {file = "gevent-24.11.1.tar.gz", hash = "sha256:8bd1419114e9e4a3ed33a5bad766afff9a3cf765cb440a582a1b3a9bc80c1aca"}, ] [package.dependencies] -cffi = {version = ">=1.12.2", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""} -greenlet = [ - {version = ">=2.0.0", markers = "platform_python_implementation == \"CPython\" and python_version < \"3.11\""}, - {version = ">=3.0rc3", markers = "platform_python_implementation == \"CPython\" and python_version >= \"3.11\""}, -] +cffi = {version = ">=1.17.1", markers = "platform_python_implementation == \"CPython\" and sys_platform == \"win32\""} +greenlet = {version = ">=3.1.1", markers = "platform_python_implementation == \"CPython\""} "zope.event" = "*" "zope.interface" = "*" @@ -3090,8 +3121,8 @@ greenlet = [ dnspython = ["dnspython (>=1.16.0,<2.0)", "idna"] docs = ["furo", "repoze.sphinx.autointerface", "sphinx", "sphinxcontrib-programoutput", "zope.schema"] monitor = ["psutil (>=5.7.0)"] -recommended = ["cffi (>=1.12.2)", "dnspython (>=1.16.0,<2.0)", "idna", "psutil (>=5.7.0)"] -test = ["cffi (>=1.12.2)", "coverage (>=5.0)", "dnspython (>=1.16.0,<2.0)", "idna", "objgraph", "psutil (>=5.7.0)", "requests", "setuptools"] +recommended = ["cffi (>=1.17.1)", "dnspython 
(>=1.16.0,<2.0)", "idna", "psutil (>=5.7.0)"] +test = ["cffi (>=1.17.1)", "coverage (>=5.0)", "dnspython (>=1.16.0,<2.0)", "idna", "objgraph", "psutil (>=5.7.0)", "requests"] [[package]] name = "gmpy2" @@ -3200,14 +3231,8 @@ files = [ [package.dependencies] google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" -grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, -] -grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, -] +grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} +grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -5540,9 +5565,6 @@ files = [ {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} - [[package]] name = "multiprocess" version = "0.70.17" @@ -6400,7 +6422,6 @@ bottleneck = {version = ">=1.3.6", optional = true, markers = "extra == \"perfor numba = {version = ">=0.56.4", optional = true, markers = "extra == \"performance\""} numexpr = {version = ">=2.8.4", optional = true, markers = "extra == \"performance\""} numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -6684,7 +6705,6 @@ files = [ deprecation = ">=2.1.0,<3.0.0" httpx = {version = ">=0.26,<0.28", extras = ["http2"]} pydantic = ">=1.9,<3.0" -strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""} [[package]] name = "posthog" @@ -7416,9 +7436,6 @@ files = [ {file = "pypdf-5.1.0.tar.gz", hash = "sha256:425a129abb1614183fd1aca6982f650b47f8026867c0ce7c4b9f281c443d2740"}, ] -[package.dependencies] -typing_extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} - [package.extras] crypto = ["cryptography"] cryptodome = ["PyCryptodome"] @@ -7507,11 +7524,9 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=1.5,<2" -tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -7549,7 +7564,6 @@ files = [ [package.dependencies] pytest = ">=8.3.3" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [package.extras] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] @@ -8367,7 +8381,6 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" -typing-extensions = 
{version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -9204,22 +9217,6 @@ httpx = {version = ">=0.26,<0.28", extras = ["http2"]} python-dateutil = ">=2.8.2,<3.0.0" typing-extensions = ">=4.2.0,<5.0.0" -[[package]] -name = "strenum" -version = "0.4.15" -description = "An Enum that inherits from str." -optional = false -python-versions = "*" -files = [ - {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, - {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, -] - -[package.extras] -docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] -release = ["twine"] -test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] - [[package]] name = "strictyaml" version = "1.7.3" @@ -9626,17 +9623,6 @@ files = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] -[[package]] -name = "tomli" -version = "2.1.0" -description = "A lil' TOML parser" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"}, - {file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"}, -] - [[package]] name = "tos" version = "2.7.2" @@ -10057,7 +10043,6 @@ h11 = ">=0.8" httptools = {version = ">=0.5.0", optional = true, markers = "extra == \"standard\""} python-dotenv = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} pyyaml = {version = ">=5.1", optional = true, markers = "extra == \"standard\""} -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} uvloop = {version = ">=0.14.0,<0.15.0 || >0.15.0,<0.15.1 || >0.15.1", optional = true, markers = "(sys_platform != \"win32\" and sys_platform != \"cygwin\") and platform_python_implementation != \"PyPy\" and extra == \"standard\""} watchfiles = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} websockets = {version = ">=10.4", optional = true, markers = "extra == \"standard\""} @@ -11040,5 +11025,5 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" -python-versions = ">=3.10,<3.13" -content-hash = "69a3f471f85dce9e5fb889f739e148a4a6d95aaf94081414503867c7157dba69" +python-versions = ">=3.11,<3.13" +content-hash = "983ba4f2cb89f0c867fc50cb48677cad9343f7f0828c7082cb0b5cf171d716fb" diff --git a/api/pyproject.toml b/api/pyproject.toml index 0d87c1b1c8..79857f8163 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,5 +1,5 @@ [project] -requires-python = ">=3.10,<3.13" +requires-python = ">=3.11,<3.13" [build-system] requires = ["poetry-core"] @@ -131,7 +131,7 @@ flask-login = "~0.6.3" flask-migrate = "~4.0.5" flask-restful = "~0.3.10" flask-sqlalchemy = "~3.1.1" -gevent = "~23.9.1" +gevent = "~24.11.1" gmpy2 = "~2.2.1" google-ai-generativelanguage = "0.6.9" google-api-core = "2.18.0" @@ -163,7 +163,7 @@ pydantic-settings = "~2.6.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" -python = ">=3.10,<3.13" +python = ">=3.11,<3.13" python-docx = "~1.1.0" python-dotenv = "1.0.0" pyyaml = "~6.0.1" @@ -242,7 +242,7 @@ tos = "~2.7.1" [tool.poetry.group.vdb.dependencies] alibabacloud_gpdb20160503 = "~3.8.0" alibabacloud_tea_openapi = "~0.3.9" -chromadb = "0.5.1" +chromadb = "0.5.20" clickhouse-connect = "~0.7.16" 
couchbase = "~4.3.0" elasticsearch = "8.14.0" @@ -268,6 +268,7 @@ weaviate-client = "~3.21.0" optional = true [tool.poetry.group.dev.dependencies] coverage = "~7.2.4" +faker = "~32.1.0" pytest = "~8.3.2" pytest-benchmark = "~4.0.0" pytest-env = "~1.1.3" diff --git a/api/pytest.ini b/api/pytest.ini index a23a4b3f3d..993da4c9a7 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -20,6 +20,7 @@ env = OPENAI_API_KEY = sk-IamNotARealKeyJustForMockTestKawaiiiiiiiiii TEI_EMBEDDING_SERVER_URL = http://a.abc.com:11451 TEI_RERANK_SERVER_URL = http://a.abc.com:11451 + TEI_API_KEY = ttttttttttttttt UPSTAGE_API_KEY = up-aaaaaaaaaaaaaaaaaaaa VOYAGE_API_KEY = va-aaaaaaaaaaaaaaaaaaaa XINFERENCE_CHAT_MODEL_UID = chat diff --git a/api/services/account_service.py b/api/services/account_service.py index 3d7f9e7dfb..aeb373bb26 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -4,7 +4,7 @@ import logging import random import secrets import uuid -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from hashlib import sha256 from typing import Any, Optional @@ -115,15 +115,15 @@ class AccountService: available_ta.current = True db.session.commit() - if datetime.now(timezone.utc).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): - account.last_active_at = datetime.now(timezone.utc).replace(tzinfo=None) + if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): + account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return account @staticmethod def get_account_jwt_token(account: Account) -> str: - exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) exp = int(exp_dt.timestamp()) payload = { "user_id": account.id, @@ -160,7 +160,7 @@ class AccountService: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() @@ -253,7 +253,7 @@ class AccountService: # If it exists, update the record account_integrate.open_id = open_id account_integrate.encrypted_token = "" # todo - account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None) else: # If it does not exist, create a new record account_integrate = AccountIntegrate( @@ -288,7 +288,7 @@ class AccountService: @staticmethod def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" - account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.last_login_at = datetime.now(UTC).replace(tzinfo=None) account.last_login_ip = ip_address db.session.add(account) db.session.commit() @@ -765,7 +765,7 @@ class RegisterService: ) account.last_login_ip = ip_address - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) @@ -805,7 +805,7 @@ class RegisterService: is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = 
datetime.now(UTC).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 915d37ec03..f45c21cb18 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -429,7 +429,7 @@ class AppAnnotationService: raise NotFound("App annotation not found") annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id - annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(annotation_setting) db.session.commit() diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py new file mode 100644 index 0000000000..b3c919dbd9 --- /dev/null +++ b/api/services/app_dsl_service.py @@ -0,0 +1,485 @@ +import logging +import uuid +from enum import StrEnum +from typing import Optional +from uuid import uuid4 + +import yaml +from packaging import version +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.helper import ssrf_proxy +from events.app_event import app_model_config_was_updated, app_was_created +from extensions.ext_redis import redis_client +from factories import variable_factory +from models import Account, App, AppMode +from models.model import AppModelConfig +from services.workflow_service import WorkflowService + +logger = logging.getLogger(__name__) + +IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" +IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes +CURRENT_DSL_VERSION = "0.1.3" + + +class ImportMode(StrEnum): + YAML_CONTENT = "yaml-content" + YAML_URL = "yaml-url" + + +class ImportStatus(StrEnum): + COMPLETED = "completed" + COMPLETED_WITH_WARNINGS = "completed-with-warnings" + PENDING = "pending" + FAILED = "failed" + + +class Import(BaseModel): + id: str + status: ImportStatus + app_id: Optional[str] = None + current_dsl_version: str = CURRENT_DSL_VERSION + imported_dsl_version: str = "" + error: str = "" + + +def _check_version_compatibility(imported_version: str) -> ImportStatus: + """Determine import status based on version comparison""" + try: + current_ver = version.parse(CURRENT_DSL_VERSION) + imported_ver = version.parse(imported_version) + except version.InvalidVersion: + return ImportStatus.FAILED + + # Compare major version and minor version + if current_ver.major != imported_ver.major or current_ver.minor != imported_ver.minor: + return ImportStatus.PENDING + + if current_ver.micro != imported_ver.micro: + return ImportStatus.COMPLETED_WITH_WARNINGS + + return ImportStatus.COMPLETED + + +class PendingData(BaseModel): + import_mode: str + yaml_content: str + name: str | None + description: str | None + icon_type: str | None + icon: str | None + icon_background: str | None + app_id: str | None + + +class AppDslService: + def __init__(self, session: Session): + self._session = session + + def import_app( + self, + *, + account: Account, + import_mode: str, + yaml_content: Optional[str] = None, + yaml_url: Optional[str] = None, + name: Optional[str] = None, + description: Optional[str] = None, + icon_type: Optional[str] = None, + icon: Optional[str] = None, + icon_background: Optional[str] = None, + app_id: Optional[str] = None, + ) -> Import: + """Import an app from YAML content or URL.""" + import_id = str(uuid.uuid4()) 
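+        # A fresh import_id is returned to the caller on every path below; when the
+        # DSL version check yields ImportStatus.PENDING, it also keys the payload
+        # cached in Redis so confirm_import() can resume this import before
+        # IMPORT_INFO_REDIS_EXPIRY elapses.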
+
+        # Validate import mode
+        try:
+            mode = ImportMode(import_mode)
+        except ValueError:
+            raise ValueError(f"Invalid import_mode: {import_mode}")
+
+        # Get YAML content
+        content = ""
+        if mode == ImportMode.YAML_URL:
+            if not yaml_url:
+                return Import(
+                    id=import_id,
+                    status=ImportStatus.FAILED,
+                    error="yaml_url is required when import_mode is yaml-url",
+                )
+            try:
+                max_size = 10 * 1024 * 1024  # 10MB
+                response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
+                response.raise_for_status()
+                content = response.content
+
+                if len(content) > max_size:
+                    return Import(
+                        id=import_id,
+                        status=ImportStatus.FAILED,
+                        error="File size exceeds the limit of 10MB",
+                    )
+
+                if not content:
+                    return Import(
+                        id=import_id,
+                        status=ImportStatus.FAILED,
+                        error="Empty content from url",
+                    )
+
+                try:
+                    content = content.decode("utf-8")
+                except UnicodeDecodeError as e:
+                    return Import(
+                        id=import_id,
+                        status=ImportStatus.FAILED,
+                        error=f"Error decoding content: {e}",
+                    )
+            except Exception as e:
+                return Import(
+                    id=import_id,
+                    status=ImportStatus.FAILED,
+                    error=f"Error fetching YAML from URL: {str(e)}",
+                )
+        elif mode == ImportMode.YAML_CONTENT:
+            if not yaml_content:
+                return Import(
+                    id=import_id,
+                    status=ImportStatus.FAILED,
+                    error="yaml_content is required when import_mode is yaml-content",
+                )
+            content = yaml_content
+
+        # Process YAML content
+        try:
+            # Parse YAML to validate format
+            data = yaml.safe_load(content)
+            if not isinstance(data, dict):
+                return Import(
+                    id=import_id,
+                    status=ImportStatus.FAILED,
+                    error="Invalid YAML format: content must be a mapping",
+                )
+
+            # Validate and fix DSL version
+            if not data.get("version"):
+                data["version"] = "0.1.0"
+            if not data.get("kind") or data.get("kind") != "app":
+                data["kind"] = "app"
+
+            imported_version = data.get("version", "0.1.0")
+            status = _check_version_compatibility(imported_version)
+
+            # Extract app data
+            app_data = data.get("app")
+            if not app_data:
+                return Import(
+                    id=import_id,
+                    status=ImportStatus.FAILED,
+                    error="Missing app data in YAML content",
+                )
+
+            # If app_id is provided, check if it exists
+            app = None
+            if app_id:
+                stmt = select(App).where(App.id == app_id, App.tenant_id == account.current_tenant_id)
+                app = self._session.scalar(stmt)
+
+                if not app:
+                    return Import(
+                        id=import_id,
+                        status=ImportStatus.FAILED,
+                        error="App not found",
+                    )
+
+                if app.mode not in [AppMode.WORKFLOW.value, AppMode.ADVANCED_CHAT.value]:
+                    return Import(
+                        id=import_id,
+                        status=ImportStatus.FAILED,
+                        error="Only workflow or advanced chat apps can be overwritten",
+                    )
+
+            # If the major or minor version mismatches, store import info in Redis
+            # and wait for the user to confirm via confirm_import
+            if status == ImportStatus.PENDING:
+                pending_data = PendingData(
+                    import_mode=import_mode,
+                    yaml_content=content,
+                    name=name,
+                    description=description,
+                    icon_type=icon_type,
+                    icon=icon,
+                    icon_background=icon_background,
+                    app_id=app_id,
+                )
+                redis_client.setex(
+                    f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
+                    IMPORT_INFO_REDIS_EXPIRY,
+                    pending_data.model_dump_json(),
+                )
+
+                return Import(
+                    id=import_id,
+                    status=status,
+                    app_id=app_id,
+                    imported_dsl_version=imported_version,
+                )
+
+            # Create or update app
+            app = self._create_or_update_app(
+                app=app,
+                data=data,
+                account=account,
+                name=name,
+                description=description,
+                icon_type=icon_type,
+                icon=icon,
+                icon_background=icon_background,
+            )
+
+            return Import(
+                id=import_id,
+                status=status,
+                app_id=app.id,
+                imported_dsl_version=imported_version,
+            )
+
+        except yaml.YAMLError as e:
+            return
Import( + id=import_id, + status=ImportStatus.FAILED, + error=f"Invalid YAML format: {str(e)}", + ) + + except Exception as e: + logger.exception("Failed to import app") + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def confirm_import(self, *, import_id: str, account: Account) -> Import: + """ + Confirm an import that requires confirmation + """ + redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}" + pending_data = redis_client.get(redis_key) + + if not pending_data: + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Import information expired or does not exist", + ) + + try: + if not isinstance(pending_data, str | bytes): + return Import( + id=import_id, + status=ImportStatus.FAILED, + error="Invalid import information", + ) + pending_data = PendingData.model_validate_json(pending_data) + data = yaml.safe_load(pending_data.yaml_content) + + app = None + if pending_data.app_id: + stmt = select(App).where(App.id == pending_data.app_id, App.tenant_id == account.current_tenant_id) + app = self._session.scalar(stmt) + + # Create or update app + app = self._create_or_update_app( + app=app, + data=data, + account=account, + name=pending_data.name, + description=pending_data.description, + icon_type=pending_data.icon_type, + icon=pending_data.icon, + icon_background=pending_data.icon_background, + ) + + # Delete import info from Redis + redis_client.delete(redis_key) + + return Import( + id=import_id, + status=ImportStatus.COMPLETED, + app_id=app.id, + current_dsl_version=CURRENT_DSL_VERSION, + imported_dsl_version=data.get("version", "0.1.0"), + ) + + except Exception as e: + logger.exception("Error confirming import") + return Import( + id=import_id, + status=ImportStatus.FAILED, + error=str(e), + ) + + def _create_or_update_app( + self, + *, + app: Optional[App], + data: dict, + account: Account, + name: Optional[str] = None, + description: Optional[str] = None, + icon_type: Optional[str] = None, + icon: Optional[str] = None, + icon_background: Optional[str] = None, + ) -> App: + """Create a new app or update an existing one.""" + app_data = data.get("app", {}) + app_mode = AppMode(app_data["mode"]) + + # Set icon type + icon_type_value = icon_type or app_data.get("icon_type") + if icon_type_value in ["emoji", "link"]: + icon_type = icon_type_value + else: + icon_type = "emoji" + icon = icon or str(app_data.get("icon", "")) + + if app: + # Update existing app + app.name = name or app_data.get("name", app.name) + app.description = description or app_data.get("description", app.description) + app.icon_type = icon_type + app.icon = icon + app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) + app.updated_by = account.id + else: + # Create new app + app = App() + app.id = str(uuid4()) + app.tenant_id = account.current_tenant_id + app.mode = app_mode.value + app.name = name or app_data.get("name", "") + app.description = description or app_data.get("description", "") + app.icon_type = icon_type + app.icon = icon + app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF") + app.enable_site = True + app.enable_api = True + app.use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) + app.created_by = account.id + app.updated_by = account.id + + self._session.add(app) + self._session.commit() + app_was_created.send(app, account=account) + + # Initialize app based on mode + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_data = 
data.get("workflow") + if not workflow_data or not isinstance(workflow_data, dict): + raise ValueError("Missing workflow data for workflow/advanced chat app") + + environment_variables_list = workflow_data.get("environment_variables", []) + environment_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + ] + conversation_variables_list = workflow_data.get("conversation_variables", []) + conversation_variables = [ + variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + ] + + workflow_service = WorkflowService() + current_draft_workflow = workflow_service.get_draft_workflow(app_model=app) + if current_draft_workflow: + unique_hash = current_draft_workflow.unique_hash + else: + unique_hash = None + workflow_service.sync_draft_workflow( + app_model=app, + graph=workflow_data.get("graph", {}), + features=workflow_data.get("features", {}), + unique_hash=unique_hash, + account=account, + environment_variables=environment_variables, + conversation_variables=conversation_variables, + ) + elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: + # Initialize model config + model_config = data.get("model_config") + if not model_config or not isinstance(model_config, dict): + raise ValueError("Missing model_config for chat/agent-chat/completion app") + # Initialize or update model config + if not app.app_model_config: + app_model_config = AppModelConfig().from_model_config_dict(model_config) + app_model_config.id = str(uuid4()) + app_model_config.app_id = app.id + app_model_config.created_by = account.id + app_model_config.updated_by = account.id + + app.app_model_config_id = app_model_config.id + + self._session.add(app_model_config) + app_model_config_was_updated.send(app, app_model_config=app_model_config) + else: + raise ValueError("Invalid app mode") + return app + + @classmethod + def export_dsl(cls, app_model: App, include_secret: bool = False) -> str: + """ + Export app + :param app_model: App instance + :return: + """ + app_mode = AppMode.value_of(app_model.mode) + + export_data = { + "version": CURRENT_DSL_VERSION, + "kind": "app", + "app": { + "name": app_model.name, + "mode": app_model.mode, + "icon": "🤖" if app_model.icon_type == "image" else app_model.icon, + "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "description": app_model.description, + "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, + }, + } + + if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + cls._append_workflow_export_data( + export_data=export_data, app_model=app_model, include_secret=include_secret + ) + else: + cls._append_model_config_export_data(export_data, app_model) + + return yaml.dump(export_data, allow_unicode=True) + + @classmethod + def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: + """ + Append workflow export data + :param export_data: export data + :param app_model: App instance + """ + workflow_service = WorkflowService() + workflow = workflow_service.get_draft_workflow(app_model) + if not workflow: + raise ValueError("Missing draft workflow configuration, please check.") + + export_data["workflow"] = workflow.to_dict(include_secret=include_secret) + + @classmethod + def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None: + """ + Append model config export data + :param export_data: export data + :param app_model: App instance + """ + 
app_model_config = app_model.app_model_config + if not app_model_config: + raise ValueError("Missing app configuration, please check.") + + export_data["model_config"] = app_model_config.to_dict() diff --git a/api/services/app_dsl_service/__init__.py b/api/services/app_dsl_service/__init__.py deleted file mode 100644 index 9fc988ffb3..0000000000 --- a/api/services/app_dsl_service/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .service import AppDslService - -__all__ = ["AppDslService"] diff --git a/api/services/app_dsl_service/exc.py b/api/services/app_dsl_service/exc.py deleted file mode 100644 index 6da4b1938f..0000000000 --- a/api/services/app_dsl_service/exc.py +++ /dev/null @@ -1,34 +0,0 @@ -class DSLVersionNotSupportedError(ValueError): - """Raised when the imported DSL version is not supported by the current Dify version.""" - - -class InvalidYAMLFormatError(ValueError): - """Raised when the provided YAML format is invalid.""" - - -class MissingAppDataError(ValueError): - """Raised when the app data is missing in the provided DSL.""" - - -class InvalidAppModeError(ValueError): - """Raised when the app mode is invalid.""" - - -class MissingWorkflowDataError(ValueError): - """Raised when the workflow data is missing in the provided DSL.""" - - -class MissingModelConfigError(ValueError): - """Raised when the model config data is missing in the provided DSL.""" - - -class FileSizeLimitExceededError(ValueError): - """Raised when the file size exceeds the allowed limit.""" - - -class EmptyContentError(ValueError): - """Raised when the content fetched from the URL is empty.""" - - -class ContentDecodingError(ValueError): - """Raised when there is an error decoding the content.""" diff --git a/api/services/app_dsl_service/service.py b/api/services/app_dsl_service/service.py deleted file mode 100644 index e6b0d9a272..0000000000 --- a/api/services/app_dsl_service/service.py +++ /dev/null @@ -1,484 +0,0 @@ -import logging -from collections.abc import Mapping -from typing import Any - -import yaml -from packaging import version - -from core.helper import ssrf_proxy -from events.app_event import app_model_config_was_updated, app_was_created -from extensions.ext_database import db -from factories import variable_factory -from models.account import Account -from models.model import App, AppMode, AppModelConfig -from models.workflow import Workflow -from services.workflow_service import WorkflowService - -from .exc import ( - ContentDecodingError, - EmptyContentError, - FileSizeLimitExceededError, - InvalidAppModeError, - InvalidYAMLFormatError, - MissingAppDataError, - MissingModelConfigError, - MissingWorkflowDataError, -) - -logger = logging.getLogger(__name__) - -current_dsl_version = "0.1.3" - - -class AppDslService: - @classmethod - def import_and_create_new_app_from_url(cls, tenant_id: str, url: str, args: dict, account: Account) -> App: - """ - Import app dsl from url and create new app - :param tenant_id: tenant id - :param url: import url - :param args: request args - :param account: Account instance - """ - max_size = 10 * 1024 * 1024 # 10MB - response = ssrf_proxy.get(url.strip(), follow_redirects=True, timeout=(10, 10)) - response.raise_for_status() - content = response.content - - if len(content) > max_size: - raise FileSizeLimitExceededError("File size exceeds the limit of 10MB") - - if not content: - raise EmptyContentError("Empty content from url") - - try: - data = content.decode("utf-8") - except UnicodeDecodeError as e: - raise ContentDecodingError(f"Error decoding content: 
{e}") - - return cls.import_and_create_new_app(tenant_id, data, args, account) - - @classmethod - def import_and_create_new_app(cls, tenant_id: str, data: str, args: dict, account: Account) -> App: - """ - Import app dsl and create new app - :param tenant_id: tenant id - :param data: import data - :param args: request args - :param account: Account instance - """ - try: - import_data = yaml.safe_load(data) - except yaml.YAMLError: - raise InvalidYAMLFormatError("Invalid YAML format in data argument.") - - # check or repair dsl version - import_data = _check_or_fix_dsl(import_data) - - app_data = import_data.get("app") - if not app_data: - raise MissingAppDataError("Missing app in data argument") - - # get app basic info - name = args.get("name") or app_data.get("name") - description = args.get("description") or app_data.get("description", "") - icon_type = args.get("icon_type") or app_data.get("icon_type") - icon = args.get("icon") or app_data.get("icon") - icon_background = args.get("icon_background") or app_data.get("icon_background") - use_icon_as_answer_icon = app_data.get("use_icon_as_answer_icon", False) - - # import dsl and create app - app_mode = AppMode.value_of(app_data.get("mode")) - - if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - workflow_data = import_data.get("workflow") - if not workflow_data or not isinstance(workflow_data, dict): - raise MissingWorkflowDataError( - "Missing workflow in data argument when app mode is advanced-chat or workflow" - ) - - app = cls._import_and_create_new_workflow_based_app( - tenant_id=tenant_id, - app_mode=app_mode, - workflow_data=workflow_data, - account=account, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - use_icon_as_answer_icon=use_icon_as_answer_icon, - ) - elif app_mode in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION}: - model_config = import_data.get("model_config") - if not model_config or not isinstance(model_config, dict): - raise MissingModelConfigError( - "Missing model_config in data argument when app mode is chat, agent-chat or completion" - ) - - app = cls._import_and_create_new_model_config_based_app( - tenant_id=tenant_id, - app_mode=app_mode, - model_config_data=model_config, - account=account, - name=name, - description=description, - icon_type=icon_type, - icon=icon, - icon_background=icon_background, - use_icon_as_answer_icon=use_icon_as_answer_icon, - ) - else: - raise InvalidAppModeError("Invalid app mode") - - return app - - @classmethod - def import_and_overwrite_workflow(cls, app_model: App, data: str, account: Account) -> Workflow: - """ - Import app dsl and overwrite workflow - :param app_model: App instance - :param data: import data - :param account: Account instance - """ - try: - import_data = yaml.safe_load(data) - except yaml.YAMLError: - raise InvalidYAMLFormatError("Invalid YAML format in data argument.") - - # check or repair dsl version - import_data = _check_or_fix_dsl(import_data) - - app_data = import_data.get("app") - if not app_data: - raise MissingAppDataError("Missing app in data argument") - - # import dsl and overwrite app - app_mode = AppMode.value_of(app_data.get("mode")) - if app_mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: - raise InvalidAppModeError("Only support import workflow in advanced-chat or workflow app.") - - if app_data.get("mode") != app_model.mode: - raise ValueError(f"App mode {app_data.get('mode')} is not matched with current app mode {app_mode.value}") - - workflow_data = 
import_data.get("workflow")
-        if not workflow_data or not isinstance(workflow_data, dict):
-            raise MissingWorkflowDataError(
-                "Missing workflow in data argument when app mode is advanced-chat or workflow"
-            )
-
-        return cls._import_and_overwrite_workflow_based_app(
-            app_model=app_model,
-            workflow_data=workflow_data,
-            account=account,
-        )
-
-    @classmethod
-    def export_dsl(cls, app_model: App, include_secret: bool = False) -> str:
-        """
-        Export app
-        :param app_model: App instance
-        :return:
-        """
-        app_mode = AppMode.value_of(app_model.mode)
-
-        export_data = {
-            "version": current_dsl_version,
-            "kind": "app",
-            "app": {
-                "name": app_model.name,
-                "mode": app_model.mode,
-                "icon": "🤖" if app_model.icon_type == "image" else app_model.icon,
-                "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background,
-                "description": app_model.description,
-                "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon,
-            },
-        }
-
-        if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
-            cls._append_workflow_export_data(
-                export_data=export_data, app_model=app_model, include_secret=include_secret
-            )
-        else:
-            cls._append_model_config_export_data(export_data, app_model)
-
-        return yaml.dump(export_data, allow_unicode=True)
-
-    @classmethod
-    def _import_and_create_new_workflow_based_app(
-        cls,
-        tenant_id: str,
-        app_mode: AppMode,
-        workflow_data: Mapping[str, Any],
-        account: Account,
-        name: str,
-        description: str,
-        icon_type: str,
-        icon: str,
-        icon_background: str,
-        use_icon_as_answer_icon: bool,
-    ) -> App:
-        """
-        Import app dsl and create new workflow based app
-
-        :param tenant_id: tenant id
-        :param app_mode: app mode
-        :param workflow_data: workflow data
-        :param account: Account instance
-        :param name: app name
-        :param description: app description
-        :param icon_type: app icon type, "emoji" or "image"
-        :param icon: app icon
-        :param icon_background: app icon background
-        :param use_icon_as_answer_icon: use app icon as answer icon
-        """
-        if not workflow_data:
-            raise MissingWorkflowDataError(
-                "Missing workflow in data argument when app mode is advanced-chat or workflow"
-            )
-
-        app = cls._create_app(
-            tenant_id=tenant_id,
-            app_mode=app_mode,
-            account=account,
-            name=name,
-            description=description,
-            icon_type=icon_type,
-            icon=icon,
-            icon_background=icon_background,
-            use_icon_as_answer_icon=use_icon_as_answer_icon,
-        )
-
-        # init draft workflow
-        environment_variables_list = workflow_data.get("environment_variables") or []
-        environment_variables = [
-            variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
-        ]
-        conversation_variables_list = workflow_data.get("conversation_variables") or []
-        conversation_variables = [
-            variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
-        ]
-        workflow_service = WorkflowService()
-        draft_workflow = workflow_service.sync_draft_workflow(
-            app_model=app,
-            graph=workflow_data.get("graph", {}),
-            features=workflow_data.get("features", {}),
-            unique_hash=None,
-            account=account,
-            environment_variables=environment_variables,
-            conversation_variables=conversation_variables,
-        )
-        workflow_service.publish_workflow(app_model=app, account=account, draft_workflow=draft_workflow)
-
-        return app
-
-    @classmethod
-    def _import_and_overwrite_workflow_based_app(
-        cls, app_model: App, workflow_data: Mapping[str, Any], account: Account
-    ) -> Workflow:
-        """
-        Import app dsl and overwrite workflow based app
-
-        :param app_model: App instance
-        :param workflow_data: workflow data
-        :param account: Account instance
-        """
-        if not workflow_data:
-            raise MissingWorkflowDataError(
-                "Missing workflow in data argument when app mode is advanced-chat or workflow"
-            )
-
-        # fetch draft workflow by app_model
-        workflow_service = WorkflowService()
-        current_draft_workflow = workflow_service.get_draft_workflow(app_model=app_model)
-        if current_draft_workflow:
-            unique_hash = current_draft_workflow.unique_hash
-        else:
-            unique_hash = None
-
-        # sync draft workflow
-        environment_variables_list = workflow_data.get("environment_variables") or []
-        environment_variables = [
-            variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list
-        ]
-        conversation_variables_list = workflow_data.get("conversation_variables") or []
-        conversation_variables = [
-            variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list
-        ]
-        draft_workflow = workflow_service.sync_draft_workflow(
-            app_model=app_model,
-            graph=workflow_data.get("graph", {}),
-            features=workflow_data.get("features", {}),
-            unique_hash=unique_hash,
-            account=account,
-            environment_variables=environment_variables,
-            conversation_variables=conversation_variables,
-        )
-
-        return draft_workflow
-
-    @classmethod
-    def _import_and_create_new_model_config_based_app(
-        cls,
-        tenant_id: str,
-        app_mode: AppMode,
-        model_config_data: Mapping[str, Any],
-        account: Account,
-        name: str,
-        description: str,
-        icon_type: str,
-        icon: str,
-        icon_background: str,
-        use_icon_as_answer_icon: bool,
-    ) -> App:
-        """
-        Import app dsl and create new model config based app
-
-        :param tenant_id: tenant id
-        :param app_mode: app mode
-        :param model_config_data: model config data
-        :param account: Account instance
-        :param name: app name
-        :param description: app description
-        :param icon: app icon
-        :param icon_background: app icon background
-        """
-        if not model_config_data:
-            raise MissingModelConfigError(
-                "Missing model_config in data argument when app mode is chat, agent-chat or completion"
-            )
-
-        app = cls._create_app(
-            tenant_id=tenant_id,
-            app_mode=app_mode,
-            account=account,
-            name=name,
-            description=description,
-            icon_type=icon_type,
-            icon=icon,
-            icon_background=icon_background,
-            use_icon_as_answer_icon=use_icon_as_answer_icon,
-        )
-
-        app_model_config = AppModelConfig()
-        app_model_config = app_model_config.from_model_config_dict(model_config_data)
-        app_model_config.app_id = app.id
-        app_model_config.created_by = account.id
-        app_model_config.updated_by = account.id
-
-        db.session.add(app_model_config)
-        db.session.commit()
-
-        app.app_model_config_id = app_model_config.id
-
-        app_model_config_was_updated.send(app, app_model_config=app_model_config)
-
-        return app
-
-    @classmethod
-    def _create_app(
-        cls,
-        tenant_id: str,
-        app_mode: AppMode,
-        account: Account,
-        name: str,
-        description: str,
-        icon_type: str,
-        icon: str,
-        icon_background: str,
-        use_icon_as_answer_icon: bool,
-    ) -> App:
-        """
-        Create new app
-
-        :param tenant_id: tenant id
-        :param app_mode: app mode
-        :param account: Account instance
-        :param name: app name
-        :param description: app description
-        :param icon_type: app icon type, "emoji" or "image"
-        :param icon: app icon
-        :param icon_background: app icon background
-        :param use_icon_as_answer_icon: use app icon as answer icon
-        """
-        app = App(
-            tenant_id=tenant_id,
-            mode=app_mode.value,
-            name=name,
-            description=description,
-            icon_type=icon_type,
-            icon=icon,
-            icon_background=icon_background,
-            enable_site=True,
-            enable_api=True,
-            use_icon_as_answer_icon=use_icon_as_answer_icon,
-            created_by=account.id,
-            updated_by=account.id,
-        )
-
-        db.session.add(app)
-        db.session.commit()
-
-        app_was_created.send(app, account=account)
-
-        return app
-
-    @classmethod
-    def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:
-        """
-        Append workflow export data
-        :param export_data: export data
-        :param app_model: App instance
-        """
-        workflow_service = WorkflowService()
-        workflow = workflow_service.get_draft_workflow(app_model)
-        if not workflow:
-            raise ValueError("Missing draft workflow configuration, please check.")
-
-        export_data["workflow"] = workflow.to_dict(include_secret=include_secret)
-
-    @classmethod
-    def _append_model_config_export_data(cls, export_data: dict, app_model: App) -> None:
-        """
-        Append model config export data
-        :param export_data: export data
-        :param app_model: App instance
-        """
-        app_model_config = app_model.app_model_config
-        if not app_model_config:
-            raise ValueError("Missing app configuration, please check.")
-
-        export_data["model_config"] = app_model_config.to_dict()
-
-
-def _check_or_fix_dsl(import_data: dict[str, Any]) -> Mapping[str, Any]:
-    """
-    Check or fix dsl
-
-    :param import_data: import data
-    :raises DSLVersionNotSupportedError: if the imported DSL version is newer than the current version
-    """
-    if not import_data.get("version"):
-        import_data["version"] = "0.1.0"
-
-    if not import_data.get("kind") or import_data.get("kind") != "app":
-        import_data["kind"] = "app"
-
-    imported_version = import_data.get("version")
-    if imported_version != current_dsl_version:
-        if imported_version and version.parse(imported_version) > version.parse(current_dsl_version):
-            errmsg = (
-                f"The imported DSL version {imported_version} is newer than "
-                f"the current supported version {current_dsl_version}. "
-                f"Please upgrade your Dify instance to import this configuration."
-            )
-            logger.warning(errmsg)
-            # raise DSLVersionNotSupportedError(errmsg)
-        else:
-            logger.warning(
-                f"DSL version {imported_version} is older than "
-                f"the current version {current_dsl_version}. "
-                f"This may cause compatibility issues."
-            )
-
-    return import_data
diff --git a/api/services/app_service.py b/api/services/app_service.py
index 620d0ac270..8d8ba735ec 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -1,6 +1,6 @@
 import json
 import logging
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import cast
 
 from flask_login import current_user
@@ -155,7 +155,7 @@ class AppService:
         """
         # get original app model config
         if app.mode == AppMode.AGENT_CHAT.value or app.is_agent:
-            model_config: AppModelConfig = app.app_model_config
+            model_config = app.app_model_config
             agent_mode = model_config.agent_mode_dict
             # decrypt agent tool parameters if it's secret-input
             for tool in agent_mode.get("tools") or []:
@@ -223,7 +223,7 @@ class AppService:
         app.icon_background = args.get("icon_background")
         app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
         app.updated_by = current_user.id
-        app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         if app.max_active_requests is not None:
@@ -240,7 +240,7 @@ class AppService:
         """
         app.name = name
         app.updated_by = current_user.id
-        app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return app
@@ -256,7 +256,7 @@ class AppService:
         app.icon = icon
         app.icon_background = icon_background
         app.updated_by = current_user.id
-        app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return app
@@ -273,7 +273,7 @@ class AppService:
 
         app.enable_site = enable_site
         app.updated_by = current_user.id
-        app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return app
@@ -290,7 +290,7 @@ class AppService:
 
         app.enable_api = enable_api
         app.updated_by = current_user.id
-        app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        app.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return app
@@ -341,7 +341,7 @@ class AppService:
         if not app_model_config:
             return meta
 
-        agent_config = app_model_config.agent_mode_dict or {}
+        agent_config = app_model_config.agent_mode_dict
 
         # get all tools
         tools = agent_config.get("tools", [])
diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py
index 2d6e901447..2e1946841f 100644
--- a/api/services/auth/auth_type.py
+++ b/api/services/auth/auth_type.py
@@ -1,6 +1,6 @@
-from enum import Enum
+from enum import StrEnum
 
 
-class AuthType(str, Enum):
+class AuthType(StrEnum):
     FIRECRAWL = "firecrawl"
     JINA = "jinareader"
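The `timezone.utc` → `UTC` hunks above rely on `datetime.UTC`, an alias for `datetime.timezone.utc` introduced in Python 3.11, which lines up with the 3.11/3.12 support matrix this PR moves to. A minimal sketch of the naive-UTC timestamp pattern these services share (the helper name is illustrative, not part of the codebase):

```python
from datetime import UTC, datetime


def naive_utc_now() -> datetime:
    """Return the current UTC time as a naive datetime.

    The tzinfo is stripped because these columns store naive UTC values.
    """
    return datetime.now(UTC).replace(tzinfo=None)


# Example: stamping a record the way the services above do.
updated_at = naive_utc_now()
assert updated_at.tzinfo is None
```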
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index f9e41988c0..8642972710 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,4 +1,5 @@
-from datetime import datetime, timezone
+from collections.abc import Callable
+from datetime import UTC, datetime
 from typing import Optional, Union
 
 from sqlalchemy import asc, desc, or_
@@ -74,14 +75,14 @@ class ConversationService:
         return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more)
 
     @classmethod
-    def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]:
+    def _get_sort_params(cls, sort_by: str):
         if sort_by.startswith("-"):
             return sort_by[1:], desc
         return sort_by, asc
 
     @classmethod
     def _build_filter_condition(
-        cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False
+        cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False
     ):
         field_value = getattr(reference_conversation, sort_field)
         if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page):
@@ -104,7 +105,7 @@ class ConversationService:
             return cls.auto_generate_name(app_model, conversation)
         else:
             conversation.name = name
-            conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
             db.session.commit()
 
         return conversation
@@ -160,5 +161,5 @@ class ConversationService:
         conversation = cls.get_conversation(app_model, conversation_id, user)
 
         conversation.is_deleted = True
-        conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
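The `_build_filter_condition` hunk replaces the builtin `callable` (a predicate function, not a type) with `collections.abc.Callable`, the correct annotation for "something you can call". The same sort-parameter pattern in isolation, assuming SQLAlchemy's `asc`/`desc`:

```python
from collections.abc import Callable

from sqlalchemy import asc, desc


def get_sort_params(sort_by: str) -> tuple[str, Callable]:
    """Map '-field' to (field, desc) and 'field' to (field, asc)."""
    if sort_by.startswith("-"):
        return sort_by[1:], desc
    return sort_by, asc


field, direction = get_sort_params("-created_at")
assert field == "created_at" and direction is desc
```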
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 806dbdf8c5..d38729f31e 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -600,7 +600,7 @@ class DocumentService:
             # update document to be paused
             document.is_paused = True
             document.paused_by = current_user.id
-            document.paused_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
 
             db.session.add(document)
             db.session.commit()
@@ -1072,7 +1072,7 @@ class DocumentService:
             document.parsing_completed_at = None
             document.cleaning_completed_at = None
             document.splitting_completed_at = None
-            document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             document.created_from = created_from
             document.doc_form = document_data["doc_form"]
             db.session.add(document)
@@ -1409,8 +1409,8 @@ class SegmentService:
                 word_count=len(content),
                 tokens=tokens,
                 status="completed",
-                indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
-                completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 created_by=current_user.id,
             )
             if document.doc_form == "qa_model":
@@ -1429,7 +1429,7 @@ class SegmentService:
             except Exception as e:
                 logging.exception("create segment index failed")
                 segment_document.enabled = False
-                segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 segment_document.status = "error"
                 segment_document.error = str(e)
                 db.session.commit()
@@ -1481,8 +1481,8 @@ class SegmentService:
                     word_count=len(content),
                     tokens=tokens,
                     status="completed",
-                    indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
-                    completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                    indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+                    completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                     created_by=current_user.id,
                 )
                 if document.doc_form == "qa_model":
@@ -1508,7 +1508,7 @@ class SegmentService:
                 logging.exception("create segment index failed")
                 for segment_document in segment_data_list:
                     segment_document.enabled = False
-                    segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                    segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                     segment_document.status = "error"
                     segment_document.error = str(e)
                 db.session.commit()
@@ -1526,7 +1526,7 @@ class SegmentService:
         if segment.enabled != action:
             if not action:
                 segment.enabled = action
-                segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 segment.disabled_by = current_user.id
                 db.session.add(segment)
                 db.session.commit()
@@ -1585,10 +1585,10 @@ class SegmentService:
             segment.word_count = len(content)
             segment.tokens = tokens
             segment.status = "completed"
-            segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
-            segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+            segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             segment.updated_by = current_user.id
-            segment.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             segment.enabled = True
             segment.disabled_at = None
             segment.disabled_by = None
@@ -1608,7 +1608,7 @@ class SegmentService:
         except Exception as e:
             logging.exception("update segment index failed")
             segment.enabled = False
-            segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             segment.status = "error"
             segment.error = str(e)
             db.session.commit()
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index 98e5d9face..7e3cd87f1e 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -1,6 +1,6 @@
 import json
 from copy import deepcopy
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Any, Optional, Union
 
 import httpx
@@ -99,7 +99,7 @@ class ExternalDatasetService:
         external_knowledge_api.description = args.get("description", "")
         external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
         external_knowledge_api.updated_by = user_id
-        external_knowledge_api.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None)
         db.session.commit()
 
         return external_knowledge_api
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index d0b04628cf..c2203b167d 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -1,4 +1,4 @@
-from enum import Enum
+from enum import StrEnum
 
 from pydantic import BaseModel, ConfigDict
 
@@ -22,7 +22,7 @@ class LimitationModel(BaseModel):
     limit: int = 0
 
 
-class LicenseStatus(str, Enum):
+class LicenseStatus(StrEnum):
     NONE = "none"
     INACTIVE = "inactive"
     ACTIVE = "active"
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 976111502c..b12b95ca13 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -77,7 +77,7 @@ class FileService:
             mime_type=mimetype,
             created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER),
             created_by=user.id,
-            created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             used=False,
             hash=hashlib.sha3_256(content).hexdigest(),
             source_url=source_url,
@@ -123,10 +123,10 @@ class FileService:
             mime_type="text/plain",
             created_by=current_user.id,
             created_by_role=CreatedByRole.ACCOUNT,
-            created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+            created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             used=True,
             used_by=current_user.id,
-            used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+            used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
         )
 
         db.session.add(upload_file)
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
index e7b9422cfe..b20bda8755 100644
--- a/api/services/model_load_balancing_service.py
+++ b/api/services/model_load_balancing_service.py
@@ -371,7 +371,7 @@ class ModelLoadBalancingService:
 
         load_balancing_config.name = name
         load_balancing_config.enabled = enabled
-        load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         db.session.commit()
 
         self._clear_credentials_cache(tenant_id, config_id)
diff --git a/api/services/recommend_app/recommend_app_type.py b/api/services/recommend_app/recommend_app_type.py
index 7ea93b3f64..e60e435b3a 100644
--- a/api/services/recommend_app/recommend_app_type.py
+++ b/api/services/recommend_app/recommend_app_type.py
@@ -1,7 +1,7 @@
-from enum import Enum
+from enum import StrEnum
 
 
-class RecommendAppType(str, Enum):
+class RecommendAppType(StrEnum):
     REMOTE = "remote"
     BUILDIN = "builtin"
     DATABASE = "db"
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index 1befa11531..a4aa870dc8 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -242,7 +242,7 @@ class ToolTransformService:
             # get tool parameters
             parameters = tool.parameters or []
             # get tool runtime parameters
-            runtime_parameters = tool.get_runtime_parameters() or []
+            runtime_parameters = tool.get_runtime_parameters()
             # override parameters
             current_parameters = parameters.copy()
             for runtime_parameter in runtime_parameters:
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 13cc9c679a..230f5d7815 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -51,8 +51,8 @@ class WebsiteService:
             excludes = options.get("excludes").split(",") if options.get("excludes") else []
             params = {
                 "crawlerOptions": {
-                    "includes": includes or [],
-                    "excludes": excludes or [],
+                    "includes": includes,
+                    "excludes": excludes,
                     "generateImgAltText": True,
                     "limit": options.get("limit", 1),
                     "returnOnlyUrls": False,
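The `str, Enum` → `StrEnum` migrations above (`AuthType`, `LicenseStatus`, `RecommendAppType`) use `enum.StrEnum`, also new in Python 3.11: members are real strings, so comparison, formatting, and serialization behave consistently without the `str` mixin. A quick illustration:

```python
from enum import StrEnum


class AuthType(StrEnum):
    FIRECRAWL = "firecrawl"
    JINA = "jinareader"


# StrEnum members compare and format as their plain string values.
assert AuthType.FIRECRAWL == "firecrawl"
assert f"{AuthType.JINA}" == "jinareader"
```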
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 7187d40517..aa2babd7f7 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -1,7 +1,7 @@
 import json
 import time
 from collections.abc import Sequence
-from datetime import datetime, timezone
+from datetime import UTC, datetime
 from typing import Optional
 
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
@@ -115,7 +115,7 @@ class WorkflowService:
             workflow.graph = json.dumps(graph)
             workflow.features = json.dumps(features)
             workflow.updated_by = account.id
-            workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
+            workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
             workflow.environment_variables = environment_variables
             workflow.conversation_variables = conversation_variables
 
@@ -148,7 +148,7 @@ class WorkflowService:
             tenant_id=app_model.tenant_id,
             app_id=app_model.id,
             type=draft_workflow.type,
-            version=str(datetime.now(timezone.utc).replace(tzinfo=None)),
+            version=str(datetime.now(UTC).replace(tzinfo=None)),
             graph=draft_workflow.graph,
             features=draft_workflow.features,
             created_by=account.id,
@@ -257,18 +257,22 @@ class WorkflowService:
         workflow_node_execution.elapsed_time = time.perf_counter() - start_at
         workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
         workflow_node_execution.created_by = account.id
-        workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
-        workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+        workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
+        workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
 
         if run_succeeded and node_run_result:
             # create workflow node execution
-            workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None
-            workflow_node_execution.process_data = (
-                json.dumps(node_run_result.process_data) if node_run_result.process_data else None
-            )
-            workflow_node_execution.outputs = (
-                json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None
+            inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None
+            process_data = (
+                WorkflowEntry.handle_special_values(node_run_result.process_data)
+                if node_run_result.process_data
+                else None
             )
+            outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
+
+            workflow_node_execution.inputs = json.dumps(inputs)
+            workflow_node_execution.process_data = json.dumps(process_data)
+            workflow_node_execution.outputs = json.dumps(outputs)
             workflow_node_execution.execution_metadata = (
                 json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None
             )
@@ -303,10 +307,10 @@ class WorkflowService:
         new_app = workflow_converter.convert_to_workflow(
             app_model=app_model,
             account=account,
-            name=args.get("name"),
-            icon_type=args.get("icon_type"),
-            icon=args.get("icon"),
-            icon_background=args.get("icon_background"),
+            name=args.get("name", "Default Name"),
+            icon_type=args.get("icon_type", "emoji"),
+            icon=args.get("icon", "🤖"),
+            icon_background=args.get("icon_background", "#FFEAD5"),
        )
 
         return new_app
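The single-step-run hunk now routes node inputs, process data, and outputs through `WorkflowEntry.handle_special_values` before `json.dumps`. The diff does not show that helper's implementation, but the intent is to normalize values `json.dumps` cannot encode (e.g. file wrappers) into plain structures first. A rough, clearly hypothetical sketch of that idea:

```python
import json
from typing import Any


def handle_special_values_sketch(value: Any) -> Any:
    """Hypothetical stand-in for WorkflowEntry.handle_special_values:
    recursively replace objects that json.dumps cannot encode with
    plain dicts/lists/strings."""
    if isinstance(value, dict):
        return {k: handle_special_values_sketch(v) for k, v in value.items()}
    if isinstance(value, list):
        return [handle_special_values_sketch(v) for v in value]
    if hasattr(value, "to_dict"):  # e.g. a File-like object
        return value.to_dict()
    return value


payload = {"answer": "ok", "nested": [{"n": 1}]}
serialized = json.dumps(handle_special_values_sketch(payload))
```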
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index b50876cc79..09be661216 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -18,9 +18,9 @@ from models.dataset import DocumentSegment
 def add_document_to_index_task(dataset_document_id: str):
     """
     Async Add document to index
-    :param document_id:
+    :param dataset_document_id:
 
-    Usage: add_document_to_index.delay(document_id)
+    Usage: add_document_to_index.delay(dataset_document_id)
     """
     logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green"))
     start_at = time.perf_counter()
@@ -74,7 +74,7 @@ def add_document_to_index_task(dataset_document_id: str):
     except Exception as e:
         logging.exception("add document to index failed")
         dataset_document.enabled = False
-        dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         dataset_document.status = "error"
         dataset_document.error = str(e)
         db.session.commit()
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index e819bf3635..0bdcd0eccd 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -52,7 +52,7 @@ def enable_annotation_reply_task(
             annotation_setting.score_threshold = score_threshold
             annotation_setting.collection_binding_id = dataset_collection_binding.id
             annotation_setting.updated_user_id = user_id
-            annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.add(annotation_setting)
         else:
             new_app_annotation_setting = AppAnnotationSetting(
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index 5ee72c27fc..dcb7009e44 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -80,9 +80,9 @@ def batch_create_segment_to_index_task(
                 word_count=len(content),
                 tokens=tokens,
                 created_by=user_id,
-                indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
                 status="completed",
-                completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+                completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
             )
             if dataset_document.doc_form == "qa_model":
                 segment_document.answer = segment["answer"]
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 4d45df4d2a..a555fb2874 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -78,6 +78,7 @@ def clean_dataset_task(
                             "Delete image_files failed when storage deleted, \
                              image_upload_file_is: {}".format(upload_file_id)
                         )
+                    db.session.delete(image_file)
                 db.session.delete(segment)
 
         db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 54c89450c9..4d328643bf 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -51,6 +51,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
                             "Delete image_files failed when storage deleted, \
                              image_upload_file_is: {}".format(upload_file_id)
                         )
+                    db.session.delete(image_file)
                 db.session.delete(segment)
 
             db.session.commit()
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index 26375743b6..315b01f157 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -38,7 +38,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
         # update segment status to indexing
         update_params = {
             DocumentSegment.status: "indexing",
-            DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+            DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
         }
         DocumentSegment.query.filter_by(id=segment.id).update(update_params)
         db.session.commit()
@@ -75,7 +75,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
         # update segment to completed
         update_params = {
             DocumentSegment.status: "completed",
-            DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
+            DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
         }
         DocumentSegment.query.filter_by(id=segment.id).update(update_params)
         db.session.commit()
@@ -87,7 +87,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
     except Exception as e:
         logging.exception("create segment to index failed")
         segment.enabled = False
-        segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+        segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
         segment.status = "error"
         segment.error = str(e)
         db.session.commit()
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index 6dd755ab03..1831691393 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -67,7 +67,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
         # check the page is updated
         if last_edited_time != page_edited_time:
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             db.session.commit()
 
             # delete all document segment and index
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index df1177d578..734dd2478a 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -50,7 +50,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
             if document:
                 document.indexing_status = "error"
                 document.error = str(e)
-                document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+                document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
                 db.session.add(document)
         db.session.commit()
         return
@@ -64,7 +64,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
 
         if document:
             document.indexing_status = "parsing"
-            document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+            document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
             documents.append(document)
             db.session.add(document)
     db.session.commit()
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index cb38bc668d..1a52a6636b 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -30,7 +30,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         raise NotFound("Document not found")
 
     document.indexing_status = "parsing"
-    document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
+    document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
     db.session.commit()
 
     # delete all document segment and index
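All of these background tasks share one failure convention, which the hunks above preserve while switching to `datetime.UTC`: on error, flag the row, record the message and a naive UTC timestamp, and commit. Schematically (record, session, and helper names are illustrative, not from the codebase):

```python
import datetime
import logging


def mark_failed(record, error: Exception, session) -> None:
    """Illustrative error path shared by the indexing tasks above."""
    logging.exception("indexing failed")
    record.enabled = False
    record.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
    record.status = "error"
    record.error = str(error)
    session.commit()
```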
failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py index 85a4f7734d..b995077984 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_llm.py @@ -11,7 +11,6 @@ from core.model_runtime.entities.message_entities import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.azure_ai_studio.llm.llm import AzureAIStudioLargeLanguageModel -from tests.integration_tests.model_runtime.__mock.azure_ai_studio import setup_azure_ai_studio_mock @pytest.mark.parametrize("setup_azure_ai_studio_mock", [["chat"]], indirect=True) diff --git a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py index 466facc5ff..4d72327c0e 100644 --- a/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/azure_ai_studio/test_rerank.py @@ -4,29 +4,21 @@ import pytest from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureAIStudioRerankModel +from core.model_runtime.model_providers.azure_ai_studio.rerank.rerank import AzureRerankModel def test_validate_credentials(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() with pytest.raises(CredentialsValidateFailedError): model.validate_credentials( model="azure-ai-studio-rerank-v1", credentials={"api_key": "invalid_key", "api_base": os.getenv("AZURE_AI_STUDIO_API_BASE")}, - query="What is the capital of the United States?", - docs=[ - "Carson City is the capital city of the American state of Nevada. At the 2010 United States " - "Census, Carson City had a population of 55,274.", - "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " - "are a political division controlled by the United States. 
Its capital is Saipan.", - ], - score_threshold=0.8, ) def test_invoke_model(): - model = AzureAIStudioRerankModel() + model = AzureRerankModel() result = model.invoke( model="azure-ai-studio-rerank-v1", diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py index b1fa9d5ca5..33160062e5 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_embeddings.py @@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock): model="reranker", credentials={ "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -60,6 +62,7 @@ def test_invoke_model(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_EMBEDDING_SERVER_URL", ""), + "api_key": os.environ.get("TEI_API_KEY", ""), }, texts=["hello", "world"], user="abc-123", diff --git a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py index cd1c20dd02..9777367063 100644 --- a/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/huggingface_tei/test_rerank.py @@ -40,6 +40,7 @@ def test_validate_credentials(setup_tei_mock): model="embedding", credentials={ "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -47,6 +48,7 @@ def test_validate_credentials(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), }, ) @@ -61,6 +63,7 @@ def test_invoke_model(setup_tei_mock): model=model_name, credentials={ "server_url": os.environ.get("TEI_RERANK_SERVER_URL"), + "api_key": os.environ.get("TEI_API_KEY", ""), }, query="Who is Kasumi?", docs=[ diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 2808b5b0fa..d98e9f6bad 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -1,10 +1,12 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest +import redis from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager from core.model_runtime.entities.model_entities import ModelType +from extensions.ext_redis import redis_client @pytest.fixture @@ -38,6 +40,9 @@ def lb_model_manager(): def test_lb_model_manager_fetch_next(mocker, lb_model_manager): + # initialize redis client + redis_client.initialize(redis.Redis()) + assert len(lb_model_manager._load_balancing_configs) == 3 config1 = lb_model_manager._load_balancing_configs[0] @@ -55,12 +60,13 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager): start_index += 1 return start_index - mocker.patch("redis.Redis.incr", side_effect=incr) - mocker.patch("redis.Redis.set", return_value=None) - mocker.patch("redis.Redis.expire", return_value=None) + with ( + patch.object(redis_client, "incr", side_effect=incr), + patch.object(redis_client, "set", 
diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py
index 2808b5b0fa..d98e9f6bad 100644
--- a/api/tests/unit_tests/core/test_model_manager.py
+++ b/api/tests/unit_tests/core/test_model_manager.py
@@ -1,10 +1,12 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
+import redis
 
 from core.entities.provider_entities import ModelLoadBalancingConfiguration
 from core.model_manager import LBModelManager
 from core.model_runtime.entities.model_entities import ModelType
+from extensions.ext_redis import redis_client
 
 
 @pytest.fixture
@@ -38,6 +40,9 @@ def lb_model_manager():
 
 
 def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
+    # initialize redis client
+    redis_client.initialize(redis.Redis())
+
     assert len(lb_model_manager._load_balancing_configs) == 3
 
     config1 = lb_model_manager._load_balancing_configs[0]
@@ -55,12 +60,13 @@ def test_lb_model_manager_fetch_next(mocker, lb_model_manager):
             start_index += 1
             return start_index
 
-    mocker.patch("redis.Redis.incr", side_effect=incr)
-    mocker.patch("redis.Redis.set", return_value=None)
-    mocker.patch("redis.Redis.expire", return_value=None)
+    with (
+        patch.object(redis_client, "incr", side_effect=incr),
+        patch.object(redis_client, "set", return_value=None),
+        patch.object(redis_client, "expire", return_value=None),
+    ):
+        config = lb_model_manager.fetch_next()
+        assert config == config2
 
-    config = lb_model_manager.fetch_next()
-    assert config == config2
-
-    config = lb_model_manager.fetch_next()
-    assert config == config3
+        config = lb_model_manager.fetch_next()
+        assert config == config3
diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
index f6b3be8250..f6555cfdde 100644
--- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
+++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py
@@ -1,6 +1,6 @@
 import uuid
 from collections.abc import Generator
-from datetime import datetime, timezone
+from datetime import UTC, datetime, timezone
 
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
@@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine
 
 
 def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]:
-    route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
+    route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None))
 
     parallel_id = graph.node_parallel_mapping.get(next_node_id)
     parallel_start_node_id = None
@@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve
     )
 
     route_node_state.status = RouteNodeState.Status.SUCCESS
-    route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
+    route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None)
     yield NodeRunSucceededEvent(
         id=node_execution_id,
         node_id=next_node_id,
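The `test_model_manager` rewrite swaps string-based `mocker.patch("redis.Redis.incr", ...)` for `unittest.mock.patch.object` on the shared `redis_client`, which targets the exact object under test and scopes the patches to a `with` block. The pattern in isolation (the client here is a stand-in, not the real `redis_client`):

```python
from unittest.mock import MagicMock, patch

client = MagicMock()  # stands in for a shared module-level client


def fake_incr(key: str) -> int:
    return 1


with (
    patch.object(client, "incr", side_effect=fake_incr),
    patch.object(client, "expire", return_value=None),
):
    assert client.incr("counter") == 1  # patched behavior
# outside the block, the original attributes are restored automatically
```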
diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py
new file mode 100644
index 0000000000..0f6b7e4ab6
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py
@@ -0,0 +1,140 @@
+from unittest.mock import Mock, PropertyMock, patch
+
+import httpx
+import pytest
+
+from core.workflow.nodes.http_request.entities import Response
+
+
+@pytest.fixture
+def mock_response():
+    response = Mock(spec=httpx.Response)
+    response.headers = {}
+    return response
+
+
+def test_is_file_with_attachment_disposition(mock_response):
+    """Test is_file when content-disposition header contains 'attachment'"""
+    mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"}
+    response = Response(mock_response)
+    assert response.is_file
+
+
+def test_is_file_with_filename_disposition(mock_response):
+    """Test is_file when content-disposition header contains filename parameter"""
+    mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"}
+    response = Response(mock_response)
+    assert response.is_file
+
+
+@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"])
+def test_is_file_with_file_content_types(mock_response, content_type):
+    """Test is_file with various file content types"""
+    mock_response.headers = {"content-type": content_type}
+    # Mock binary content
+    type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
+    response = Response(mock_response)
+    assert response.is_file, f"Content type {content_type} should be identified as a file"
+
+
+@pytest.mark.parametrize(
+    "content_type",
+    [
+        "application/json",
+        "application/xml",
+        "application/javascript",
+        "application/x-www-form-urlencoded",
+        "application/yaml",
+        "application/graphql",
+    ],
+)
+def test_text_based_application_types(mock_response, content_type):
+    """Test common text-based application types are not identified as files"""
+    mock_response.headers = {"content-type": content_type}
+    response = Response(mock_response)
+    assert not response.is_file, f"Content type {content_type} should not be identified as a file"
+
+
+@pytest.mark.parametrize(
+    ("content", "content_type"),
+    [
+        (b'{"key": "value"}', "application/octet-stream"),
+        (b"[1, 2, 3]", "application/unknown"),
+        (b"function test() {}", "application/x-unknown"),
+        (b"test", "application/binary"),
+        (b"var x = 1;", "application/data"),
+    ],
+)
+def test_content_based_detection(mock_response, content, content_type):
+    """Test content-based detection for text-like content"""
+    mock_response.headers = {"content-type": content_type}
+    type(mock_response).content = PropertyMock(return_value=content)
+    response = Response(mock_response)
+    assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file"
+
+
+@pytest.mark.parametrize(
+    ("content", "content_type"),
+    [
+        (bytes([0x00, 0xFF] * 512), "application/octet-stream"),
+        (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"),  # PNG magic numbers
+        (bytes([0xFF, 0xD8, 0xFF]), "application/binary"),  # JPEG magic numbers
+    ],
+)
+def test_binary_content_detection(mock_response, content, content_type):
+    """Test content-based detection for binary content"""
+    mock_response.headers = {"content-type": content_type}
+    type(mock_response).content = PropertyMock(return_value=content)
+    response = Response(mock_response)
+    assert response.is_file, f"Binary content with type {content_type} should be identified as a file"
+
+
+@pytest.mark.parametrize(
+    ("content_type", "expected_main_type"),
+    [
+        ("x-world/x-vrml", "model"),  # VRML 3D model
+        ("font/ttf", "application"),  # TrueType font
+        ("text/csv", "text"),  # CSV text file
+        ("unknown/xyz", None),  # Unknown type
+    ],
+)
+def test_mimetype_based_detection(mock_response, content_type, expected_main_type):
+    """Test detection using mimetypes.guess_type for non-application content types"""
+    mock_response.headers = {"content-type": content_type}
+    type(mock_response).content = PropertyMock(return_value=bytes([0x00]))  # Dummy content
+
+    with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type:
+        # Mock the return value based on expected_main_type
+        if expected_main_type:
+            mock_guess_type.return_value = (f"{expected_main_type}/subtype", None)
+        else:
+            mock_guess_type.return_value = (None, None)
+
+        response = Response(mock_response)
+
+        # Check if the result matches our expectation
+        if expected_main_type in ("application", "image", "audio", "video"):
+            assert response.is_file, f"Content type {content_type} should be identified as a file"
+        else:
+            assert not response.is_file, f"Content type {content_type} should not be identified as a file"
+
+        # Verify that guess_type was called
+        mock_guess_type.assert_called_once()
+
+
+def test_is_file_with_inline_disposition(mock_response):
+    """Test is_file when content-disposition is 'inline'"""
+    mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"}
+    # Mock binary content
+    type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
+    response = Response(mock_response)
+    assert response.is_file
+
+
+def test_is_file_with_no_content_disposition(mock_response):
+    """Test is_file when no content-disposition header is present"""
+    mock_response.headers = {"content-type": "application/pdf"}
+    # Mock binary content
+    type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512))
+    response = Response(mock_response)
+    assert response.is_file
diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
index def6c2a232..9a24d35a1f 100644
--- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py
@@ -1,125 +1,431 @@
+from collections.abc import Sequence
+from typing import Optional
+
 import pytest
 
-from core.app.entities.app_invoke_entities import InvokeFrom
+from configs import dify_config
+from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
+from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
+from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
 from core.file import File, FileTransferMethod, FileType
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.common_entities import I18nObject
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessage,
+    PromptMessageRole,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelFeature, ModelType, ProviderModel
+from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity
+from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
-from core.workflow.nodes.llm.entities import ContextConfig, LLMNodeData, ModelConfig, VisionConfig, VisionConfigOptions
+from core.workflow.nodes.llm.entities import (
+    ContextConfig,
+    LLMNodeChatModelMessage,
+    LLMNodeData,
+    ModelConfig,
+    VisionConfig,
+    VisionConfigOptions,
+)
 from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
+from models.provider import ProviderType
 from models.workflow import WorkflowType
+from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario
 
 
-class TestLLMNode:
-    @pytest.fixture
-    def llm_node(self):
-        data = LLMNodeData(
-            title="Test LLM",
-            model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
-            prompt_template=[],
-            memory=None,
-            context=ContextConfig(enabled=False),
-            vision=VisionConfig(
-                enabled=True,
-                configs=VisionConfigOptions(
-                    variable_selector=["sys", "files"],
-                    detail=ImagePromptMessageContent.DETAIL.HIGH,
-                ),
-            ),
-        )
-        variable_pool = VariablePool(
-            system_variables={},
-            user_inputs={},
-        )
-        node = LLMNode(
-            id="1",
-            config={
-                "id": "1",
-                "data": data.model_dump(),
-            },
-            graph_init_params=GraphInitParams(
-                tenant_id="1",
-                app_id="1",
-                workflow_type=WorkflowType.WORKFLOW,
-                workflow_id="1",
-                graph_config={},
-                user_id="1",
-                user_from=UserFrom.ACCOUNT,
-                invoke_from=InvokeFrom.SERVICE_API,
-                call_depth=0,
-            ),
-            graph=Graph(
-                root_node_id="1",
-                answer_stream_generate_routes=AnswerStreamGenerateRoute(
-                    answer_dependencies={},
-                    answer_generate_route={},
-                ),
-                end_stream_param=EndStreamParam(
-                    end_dependencies={},
-                    end_stream_variable_selector_mapping={},
-                ),
-            ),
-            graph_runtime_state=GraphRuntimeState(
-                variable_pool=variable_pool,
-                start_at=0,
-            ),
-        )
-        return node
+class MockTokenBufferMemory:
+    def __init__(self, history_messages=None):
+        self.history_messages = history_messages or []
 
-    def test_fetch_files_with_file_segment(self, llm_node):
-        file = File(
+    def get_history_prompt_messages(
+        self, max_token_limit: int = 2000, message_limit: Optional[int] = None
+    ) -> Sequence[PromptMessage]:
+        if message_limit is not None:
+            return self.history_messages[-message_limit * 2 :]
+        return self.history_messages
+
+
+@pytest.fixture
+def llm_node():
+    data = LLMNodeData(
+        title="Test LLM",
+        model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
+        prompt_template=[],
+        memory=None,
+        context=ContextConfig(enabled=False),
+        vision=VisionConfig(
+            enabled=True,
+            configs=VisionConfigOptions(
+                variable_selector=["sys", "files"],
+                detail=ImagePromptMessageContent.DETAIL.HIGH,
+            ),
+        ),
+    )
+    variable_pool = VariablePool(
+        system_variables={},
+        user_inputs={},
+    )
+    node = LLMNode(
+        id="1",
+        config={
+            "id": "1",
+            "data": data.model_dump(),
+        },
+        graph_init_params=GraphInitParams(
+            tenant_id="1",
+            app_id="1",
+            workflow_type=WorkflowType.WORKFLOW,
+            workflow_id="1",
+            graph_config={},
+            user_id="1",
+            user_from=UserFrom.ACCOUNT,
+            invoke_from=InvokeFrom.SERVICE_API,
+            call_depth=0,
+        ),
+        graph=Graph(
+            root_node_id="1",
+            answer_stream_generate_routes=AnswerStreamGenerateRoute(
+                answer_dependencies={},
+                answer_generate_route={},
+            ),
+            end_stream_param=EndStreamParam(
+                end_dependencies={},
+                end_stream_variable_selector_mapping={},
+            ),
+        ),
+        graph_runtime_state=GraphRuntimeState(
+            variable_pool=variable_pool,
+            start_at=0,
+        ),
+    )
+    return node
+
+
+@pytest.fixture
+def model_config():
+    # Create actual provider and model type instances
+    model_provider_factory = ModelProviderFactory()
+    provider_instance = model_provider_factory.get_provider_instance("openai")
+    model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
+
+    # Create a ProviderModelBundle
+    provider_model_bundle = ProviderModelBundle(
+        configuration=ProviderConfiguration(
+            tenant_id="1",
+            provider=provider_instance.get_provider_schema(),
+            preferred_provider_type=ProviderType.CUSTOM,
+            using_provider_type=ProviderType.CUSTOM,
+            system_configuration=SystemConfiguration(enabled=False),
+            custom_configuration=CustomConfiguration(provider=None),
+            model_settings=[],
+        ),
+        provider_instance=provider_instance,
+        model_type_instance=model_type_instance,
+    )
+
+    # Create and return a ModelConfigWithCredentialsEntity
+    return ModelConfigWithCredentialsEntity(
+        provider="openai",
+        model="gpt-3.5-turbo",
+        model_schema=AIModelEntity(
+            model="gpt-3.5-turbo",
+            label=I18nObject(en_US="GPT-3.5 Turbo"),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={},
+        ),
+        mode="chat",
+        credentials={},
+        parameters={},
+        provider_model_bundle=provider_model_bundle,
+    )
+
+
+def test_fetch_files_with_file_segment(llm_node):
+    file = File(
+        id="1",
+        tenant_id="test",
+        type=FileType.IMAGE,
+        filename="test.jpg",
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        related_id="1",
+    )
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == [file]
+
+
+def test_fetch_files_with_array_file_segment(llm_node):
+    files = [
+        File(
             id="1",
             tenant_id="test",
             type=FileType.IMAGE,
-            filename="test.jpg",
+            filename="test1.jpg",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="1",
+        ),
+        File(
+            id="2",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test2.jpg",
+            transfer_method=FileTransferMethod.LOCAL_FILE,
+            related_id="2",
+        ),
+    ]
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == files
+
+
+def test_fetch_files_with_none_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_array_any_segment(llm_node):
+    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_files_with_non_existent_variable(llm_node):
+    result = llm_node._fetch_files(selector=["sys", "files"])
+    assert result == []
+
+
+def test_fetch_prompt_messages__vison_disabled(faker, llm_node, model_config):
+    prompt_template = []
+    llm_node.node_data.prompt_template = prompt_template
+
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
+    files = [
+        File(
+            id="1",
+            tenant_id="test",
+            type=FileType.IMAGE,
+            filename="test1.jpg",
+            transfer_method=FileTransferMethod.REMOTE_URL,
+            remote_url=fake_remote_url,
         )
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+    ]
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == [file]
+    fake_query = faker.sentence()
 
-    def test_fetch_files_with_array_file_segment(self, llm_node):
-        files = [
-            File(
-                id="1",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test1.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="1",
-            ),
-            File(
-                id="2",
-                tenant_id="test",
-                type=FileType.IMAGE,
-                filename="test2.jpg",
-                transfer_method=FileTransferMethod.LOCAL_FILE,
-                related_id="2",
-            ),
-        ]
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+    prompt_messages, _ = llm_node._fetch_prompt_messages(
+        user_query=fake_query,
+        user_files=files,
+        context=None,
+        memory=None,
+        model_config=model_config,
+        prompt_template=prompt_template,
+        memory_config=None,
+        vision_enabled=False,
+        vision_detail=fake_vision_detail,
+        variable_pool=llm_node.graph_runtime_state.variable_pool,
+        jinja2_variables=[],
+    )
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == files
+    assert prompt_messages == [UserPromptMessage(content=fake_query)]
 
-    def test_fetch_files_with_none_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+def test_fetch_prompt_messages__basic(faker, llm_node, model_config):
+    # Setup dify config
+    dify_config.MULTIMODAL_SEND_IMAGE_FORMAT = "url"
+    dify_config.MULTIMODAL_SEND_VIDEO_FORMAT = "url"
 
-    def test_fetch_files_with_array_any_segment(self, llm_node):
-        llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+    # Generate fake values for prompt template
+    fake_assistant_prompt = faker.sentence()
+    fake_query = faker.sentence()
+    fake_context = faker.sentence()
+    fake_window_size = faker.random_int(min=1, max=3)
+    fake_vision_detail = faker.random_element(
+        [ImagePromptMessageContent.DETAIL.HIGH, ImagePromptMessageContent.DETAIL.LOW]
+    )
+    fake_remote_url = faker.url()
 
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+    # Setup mock memory with history messages
+    mock_history = [
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+        UserPromptMessage(content=faker.sentence()),
+        AssistantPromptMessage(content=faker.sentence()),
+    ]
 
-    def test_fetch_files_with_non_existent_variable(self, llm_node):
-        result = llm_node._fetch_files(selector=["sys", "files"])
-        assert result == []
+    # Setup memory configuration
+    memory_config = MemoryConfig(
+        role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"),
+        window=MemoryConfig.WindowConfig(enabled=True, size=fake_window_size),
+        query_prompt_template=None,
+    )
+
+    memory = MockTokenBufferMemory(history_messages=mock_history)
+
+    # Test scenarios covering different file input combinations
+    test_scenarios = [
+        LLMNodeTestScenario(
+            description="No files",
+            user_query=fake_query,
+            user_files=[],
+            features=[],
+            vision_enabled=False,
+            vision_detail=None,
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(content=fake_query),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="User files",
+            user_query=fake_query,
+            user_files=[
+                File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            ],
+            vision_enabled=True,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text=fake_context,
+                    role=PromptMessageRole.SYSTEM,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text="{#context#}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+                LLMNodeChatModelMessage(
+                    text=fake_assistant_prompt,
+                    role=PromptMessageRole.ASSISTANT,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                SystemPromptMessage(content=fake_context),
+                UserPromptMessage(content=fake_context),
+                AssistantPromptMessage(content=fake_assistant_prompt),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [
+                UserPromptMessage(
+                    content=[
+                        TextPromptMessageContent(data=fake_query),
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ],
+        ),
+        LLMNodeTestScenario(
+            description="Prompt template with variable selector of File",
+            user_query=fake_query,
+            user_files=[],
+            vision_enabled=False,
+            vision_detail=fake_vision_detail,
+            features=[ModelFeature.VISION],
+            window_size=fake_window_size,
+            prompt_template=[
+                LLMNodeChatModelMessage(
+                    text="{{#input.image#}}",
+                    role=PromptMessageRole.USER,
+                    edition_type="basic",
+                ),
+            ],
+            expected_messages=[
+                UserPromptMessage(
+                    content=[
+                        ImagePromptMessageContent(data=fake_remote_url, detail=fake_vision_detail),
+                    ]
+                ),
+            ]
+            + mock_history[fake_window_size * -2 :]
+            + [UserPromptMessage(content=fake_query)],
+            file_variables={
+                "input.image": File(
+                    tenant_id="test",
+                    type=FileType.IMAGE,
+                    filename="test1.jpg",
+                    transfer_method=FileTransferMethod.REMOTE_URL,
+                    remote_url=fake_remote_url,
+                )
+            },
+        ),
+    ]
+
+    for scenario in test_scenarios:
+        model_config.model_schema.features = scenario.features
+
+        for k, v in scenario.file_variables.items():
+            selector = k.split(".")
+            llm_node.graph_runtime_state.variable_pool.add(selector, v)
+
+        # Call the method under test
+        prompt_messages, _ = llm_node._fetch_prompt_messages(
+            user_query=scenario.user_query,
+            user_files=scenario.user_files,
+            context=fake_context,
+            memory=memory,
+            model_config=model_config,
+            prompt_template=scenario.prompt_template,
+            memory_config=memory_config,
+            vision_enabled=scenario.vision_enabled,
+            vision_detail=scenario.vision_detail,
+            variable_pool=llm_node.graph_runtime_state.variable_pool,
+            jinja2_variables=[],
+        )
+
+        # Verify the result
+        assert len(prompt_messages) == len(scenario.expected_messages), f"Scenario failed: {scenario.description}"
+        assert (
+            prompt_messages == scenario.expected_messages
+        ), f"Message content mismatch in scenario: {scenario.description}"
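`LLMNodeTestScenario` (defined in the new `test_scenarios.py` that follows) turns `test_fetch_prompt_messages__basic` into a table-driven test: each case is a validated pydantic object instead of a loose dict. Constructing one looks roughly like this (all field values here are illustrative):

```python
from core.model_runtime.entities.message_entities import PromptMessageRole, UserPromptMessage
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage
from tests.unit_tests.core.workflow.nodes.llm.test_scenarios import LLMNodeTestScenario

scenario = LLMNodeTestScenario(
    description="Plain user prompt",
    user_query="Hello",
    user_files=[],
    vision_enabled=False,
    vision_detail=None,
    features=[],
    window_size=2,
    prompt_template=[
        LLMNodeChatModelMessage(text="Hello", role=PromptMessageRole.USER, edition_type="basic"),
    ],
    expected_messages=[UserPromptMessage(content="Hello")],
)
assert scenario.description == "Plain user prompt"
```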
description="Expected messages after processing") diff --git a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py b/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py deleted file mode 100644 index 842e8268d1..0000000000 --- a/api/tests/unit_tests/services/app_dsl_service/test_app_dsl_service.py +++ /dev/null @@ -1,47 +0,0 @@ -import pytest -from packaging import version - -from services.app_dsl_service import AppDslService -from services.app_dsl_service.exc import DSLVersionNotSupportedError -from services.app_dsl_service.service import _check_or_fix_dsl, current_dsl_version - - -class TestAppDSLService: - @pytest.mark.skip(reason="Test skipped") - def test_check_or_fix_dsl_missing_version(self): - import_data = {} - result = _check_or_fix_dsl(import_data) - assert result["version"] == "0.1.0" - assert result["kind"] == "app" - - @pytest.mark.skip(reason="Test skipped") - def test_check_or_fix_dsl_missing_kind(self): - import_data = {"version": "0.1.0"} - result = _check_or_fix_dsl(import_data) - assert result["kind"] == "app" - - @pytest.mark.skip(reason="Test skipped") - def test_check_or_fix_dsl_older_version(self): - import_data = {"version": "0.0.9", "kind": "app"} - result = _check_or_fix_dsl(import_data) - assert result["version"] == "0.0.9" - - @pytest.mark.skip(reason="Test skipped") - def test_check_or_fix_dsl_current_version(self): - import_data = {"version": current_dsl_version, "kind": "app"} - result = _check_or_fix_dsl(import_data) - assert result["version"] == current_dsl_version - - @pytest.mark.skip(reason="Test skipped") - def test_check_or_fix_dsl_newer_version(self): - current_version = version.parse(current_dsl_version) - newer_version = f"{current_version.major}.{current_version.minor + 1}.0" - import_data = {"version": newer_version, "kind": "app"} - with pytest.raises(DSLVersionNotSupportedError): - _check_or_fix_dsl(import_data) - - @pytest.mark.skip(reason="Test skipped") - def test_check_or_fix_dsl_invalid_kind(self): - import_data = {"version": current_dsl_version, "kind": "invalid"} - result = _check_or_fix_dsl(import_data) - assert result["kind"] == "app" diff --git a/docker-legacy/docker-compose.chroma.yaml b/docker-legacy/docker-compose.chroma.yaml index a943d620c0..63354305de 100644 --- a/docker-legacy/docker-compose.chroma.yaml +++ b/docker-legacy/docker-compose.chroma.yaml @@ -1,7 +1,7 @@ services: # Chroma vector store. chroma: - image: ghcr.io/chroma-core/chroma:0.5.1 + image: ghcr.io/chroma-core/chroma:0.5.20 restart: always volumes: - ./volumes/chroma:/chroma/chroma diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 7bf2cd4708..7ddb98e272 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.12.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.12.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -397,7 +397,7 @@ services: # Frontend web application. 
   web:
-    image: langgenius/dify-web:0.11.2
+    image: langgenius/dify-web:0.12.0
     restart: always
     environment:
       # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is
diff --git a/docker/.env.example b/docker/.env.example
index be8d72339f..50dc56a5c9 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -240,6 +240,12 @@ REDIS_SENTINEL_USERNAME=
 REDIS_SENTINEL_PASSWORD=
 REDIS_SENTINEL_SOCKET_TIMEOUT=0.1
 
+# List of Redis Cluster nodes. If Cluster mode is enabled, provide at least one Cluster IP and port.
+# Format: `<ip1>:<port1>,<ip2>:<port2>,<ip3>:<port3>`
+REDIS_USE_CLUSTERS=false
+REDIS_CLUSTERS=
+REDIS_CLUSTERS_PASSWORD=
+
 # ------------------------------
 # Celery Configuration
 # ------------------------------
@@ -568,9 +574,11 @@ UPLOAD_FILE_BATCH_LIMIT=5
 # `Unstructured` Unstructured.io file extraction scheme
 ETL_TYPE=dify
-# Unstructured API path, needs to be configured when ETL_TYPE is Unstructured.
+# Unstructured API path and API key; these need to be configured when ETL_TYPE is Unstructured,
+# or when using Unstructured for the document extractor node with pptx files.
 # For example: http://unstructured:8000/general/v0/general
 UNSTRUCTURED_API_URL=
+UNSTRUCTURED_API_KEY=
 
 # ------------------------------
 # Model Configuration
diff --git a/docker/README.md b/docker/README.md
index 7ce3f9bd75..c3cd1f9e3c 100644
--- a/docker/README.md
+++ b/docker/README.md
@@ -36,7 +36,7 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
    - Navigate to the `docker` directory.
    - Ensure the `middleware.env` file is created by running `cp middleware.env.example middleware.env` (refer to the `middleware.env.example` file).
 2. **Running Middleware Services**:
-   - Execute `docker-compose -f docker-compose.middleware.yaml up -d` to start the middleware services.
+   - Execute `docker-compose -f docker-compose.middleware.yaml --env-file middleware.env up -d` to start the middleware services.
 
 ### Migration for Existing Users
 
diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml
index 2eea273e72..11f5302197 100644
--- a/docker/docker-compose.middleware.yaml
+++ b/docker/docker-compose.middleware.yaml
@@ -29,11 +29,13 @@ services:
   redis:
     image: redis:6-alpine
     restart: always
+    environment:
+      REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456}
     volumes:
       # Mount the redis data directory to the container.
       - ${REDIS_HOST_VOLUME:-./volumes/redis/data}:/data
     # Set the redis password when starting the redis server.
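     # REDISCLI_AUTH (added above) lets redis-cli, which the container healthcheck runs,
     # authenticate without passing the password on the command line; keep it in sync
     # with the --requirepass value below.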
- command: redis-server --requirepass difyai123456 + command: redis-server --requirepass ${REDIS_PASSWORD:-difyai123456} ports: - "${EXPOSE_REDIS_PORT:-6379}:6379" healthcheck: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index b6caff90d9..9a135e7b54 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -55,6 +55,9 @@ x-shared-env: &shared-api-worker-env REDIS_SENTINEL_USERNAME: ${REDIS_SENTINEL_USERNAME:-} REDIS_SENTINEL_PASSWORD: ${REDIS_SENTINEL_PASSWORD:-} REDIS_SENTINEL_SOCKET_TIMEOUT: ${REDIS_SENTINEL_SOCKET_TIMEOUT:-0.1} + REDIS_CLUSTERS: ${REDIS_CLUSTERS:-} + REDIS_USE_CLUSTERS: ${REDIS_USE_CLUSTERS:-false} + REDIS_CLUSTERS_PASSWORD: ${REDIS_CLUSTERS_PASSWORD:-} ACCESS_TOKEN_EXPIRE_MINUTES: ${ACCESS_TOKEN_EXPIRE_MINUTES:-60} CELERY_BROKER_URL: ${CELERY_BROKER_URL:-redis://:difyai123456@redis:6379/1} BROKER_USE_SSL: ${BROKER_USE_SSL:-false} @@ -174,7 +177,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} - LINDORM_PASSWORD: ${LINDORM_USERNAME:-lindorm } + LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm } KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} @@ -219,6 +222,7 @@ x-shared-env: &shared-api-worker-env UPLOAD_FILE_BATCH_LIMIT: ${UPLOAD_FILE_BATCH_LIMIT:-5} ETL_TYPE: ${ETL_TYPE:-dify} UNSTRUCTURED_API_URL: ${UNSTRUCTURED_API_URL:-} + UNSTRUCTURED_API_KEY: ${UNSTRUCTURED_API_KEY:-} PROMPT_GENERATION_MAX_TOKENS: ${PROMPT_GENERATION_MAX_TOKENS:-512} CODE_GENERATION_MAX_TOKENS: ${CODE_GENERATION_MAX_TOKENS:-1024} MULTIMODAL_SEND_IMAGE_FORMAT: ${MULTIMODAL_SEND_IMAGE_FORMAT:-base64} @@ -287,7 +291,7 @@ x-shared-env: &shared-api-worker-env services: # API service api: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.12.0 restart: always environment: # Use the shared environment variables. @@ -307,7 +311,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.12.0 restart: always environment: # Use the shared environment variables. @@ -326,7 +330,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.11.2 + image: langgenius/dify-web:0.12.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -363,6 +367,8 @@ services: redis: image: redis:6-alpine restart: always + environment: + REDISCLI_AUTH: ${REDIS_PASSWORD:-difyai123456} volumes: # Mount the redis data directory to the container. 
- ./volumes/redis/data:/data @@ -599,7 +605,7 @@ services: # Chroma vector database chroma: - image: ghcr.io/chroma-core/chroma:0.5.1 + image: ghcr.io/chroma-core/chroma:0.5.20 profiles: - chroma restart: always diff --git a/docker/middleware.env.example b/docker/middleware.env.example index 17ac819527..c4ce9f0114 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -42,11 +42,13 @@ POSTGRES_EFFECTIVE_CACHE_SIZE=4096MB # ----------------------------- # Environment Variables for redis Service -REDIS_HOST_VOLUME=./volumes/redis/data # ----------------------------- +REDIS_HOST_VOLUME=./volumes/redis/data +REDIS_PASSWORD=difyai123456 # ------------------------------ # Environment Variables for sandbox Service +# ------------------------------ SANDBOX_API_KEY=dify-sandbox SANDBOX_GIN_MODE=release SANDBOX_WORKER_TIMEOUT=15 @@ -54,7 +56,6 @@ SANDBOX_ENABLE_NETWORK=true SANDBOX_HTTP_PROXY=http://ssrf_proxy:3128 SANDBOX_HTTPS_PROXY=http://ssrf_proxy:3128 SANDBOX_PORT=8194 -# ------------------------------ # ------------------------------ # Environment Variables for ssrf_proxy Service diff --git a/web/app/components/app/configuration/config-vision/index.tsx b/web/app/components/app/configuration/config-vision/index.tsx index 23f00d46d8..f30d3e4a0a 100644 --- a/web/app/components/app/configuration/config-vision/index.tsx +++ b/web/app/components/app/configuration/config-vision/index.tsx @@ -12,34 +12,46 @@ import ConfigContext from '@/context/debug-configuration' // import { Resolution } from '@/types/app' import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' import Switch from '@/app/components/base/switch' -import type { FileUpload } from '@/app/components/base/features/types' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' const ConfigVision: FC = () => { const { t } = useTranslation() - const { isShowVisionConfig } = useContext(ConfigContext) + const { isShowVisionConfig, isAllowVideoUpload } = useContext(ConfigContext) const file = useFeatures(s => s.features.file) const featuresStore = useFeaturesStore() - const handleChange = useCallback((data: FileUpload) => { + const isImageEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.image) ?? false + + const handleChange = useCallback((value: boolean) => { const { features, setFeatures, } = featuresStore!.getState() const newFeatures = produce(features, (draft) => { - draft.file = { - ...draft.file, - enabled: data.enabled, - image: { - enabled: data.enabled, - detail: data.image?.detail, - transfer_methods: data.image?.transfer_methods, - number_limits: data.image?.number_limits, - }, + if (value) { + draft.file!.allowed_file_types = Array.from(new Set([ + ...(draft.file?.allowed_file_types || []), + SupportUploadFileTypes.image, + ...(isAllowVideoUpload ? [SupportUploadFileTypes.video] : []), + ])) + } + else { + draft.file!.allowed_file_types = draft.file!.allowed_file_types?.filter( + type => type !== SupportUploadFileTypes.image && (isAllowVideoUpload ? type !== SupportUploadFileTypes.video : true), + ) + } + + if (draft.file) { + draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 + draft.file.image = { + ...(draft.file.image || {}), + enabled: value, + } } }) setFeatures(newFeatures) - }, [featuresStore]) + }, [featuresStore, isAllowVideoUpload]) if (!isShowVisionConfig) return null @@ -89,11 +101,8 @@ const ConfigVision: FC = () => {
handleChange({ - ...(file || {}), - enabled: value, - })} + defaultValue={isImageEnabled} + onChange={handleChange} size='md' /> diff --git a/web/app/components/app/configuration/config/config-document.tsx b/web/app/components/app/configuration/config/config-document.tsx new file mode 100644 index 0000000000..1ac6da0dd8 --- /dev/null +++ b/web/app/components/app/configuration/config/config-document.tsx @@ -0,0 +1,78 @@ +'use client' +import type { FC } from 'react' +import React, { useCallback } from 'react' +import { useTranslation } from 'react-i18next' +import produce from 'immer' +import { useContext } from 'use-context-selector' + +import { Document } from '@/app/components/base/icons/src/vender/features' +import Tooltip from '@/app/components/base/tooltip' +import ConfigContext from '@/context/debug-configuration' +import { SupportUploadFileTypes } from '@/app/components/workflow/types' +import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks' +import Switch from '@/app/components/base/switch' + +const ConfigDocument: FC = () => { + const { t } = useTranslation() + const file = useFeatures(s => s.features.file) + const featuresStore = useFeaturesStore() + const { isShowDocumentConfig } = useContext(ConfigContext) + + const isDocumentEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.document) ?? false + + const handleChange = useCallback((value: boolean) => { + const { + features, + setFeatures, + } = featuresStore!.getState() + + const newFeatures = produce(features, (draft) => { + if (value) { + draft.file!.allowed_file_types = Array.from(new Set([ + ...(draft.file?.allowed_file_types || []), + SupportUploadFileTypes.document, + ])) + } + else { + draft.file!.allowed_file_types = draft.file!.allowed_file_types?.filter( + type => type !== SupportUploadFileTypes.document, + ) + } + if (draft.file) + draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0 + }) + setFeatures(newFeatures) + }, [featuresStore]) + + if (!isShowDocumentConfig) + return null + + return ( +
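+    // The markup below mirrors the ConfigVision feature row: icon, title with tooltip, and a Switch wired to handleChange.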
+
+
+ +
+
+
+
{t('appDebug.feature.documentUpload.title')}
+ + {t('appDebug.feature.documentUpload.description')} +
+ } + /> +
+
+
+ +
+ + ) +} +export default React.memo(ConfigDocument) diff --git a/web/app/components/app/configuration/config/index.tsx b/web/app/components/app/configuration/config/index.tsx index 8687079931..39fdd502ef 100644 --- a/web/app/components/app/configuration/config/index.tsx +++ b/web/app/components/app/configuration/config/index.tsx @@ -7,6 +7,7 @@ import { useFormattingChangedDispatcher } from '../debug/hooks' import DatasetConfig from '../dataset-config' import HistoryPanel from '../config-prompt/conversation-history/history-panel' import ConfigVision from '../config-vision' +import ConfigDocument from './config-document' import AgentTools from './agent/agent-tools' import ConfigContext from '@/context/debug-configuration' import ConfigPrompt from '@/app/components/app/configuration/config-prompt' @@ -82,6 +83,8 @@ const Config: FC = () => { + + {/* Chat History */} {isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && ( { } const isShowVisionConfig = !!currModel?.features?.includes(ModelFeatureEnum.vision) - + const isShowDocumentConfig = !!currModel?.features?.includes(ModelFeatureEnum.document) + const isAllowVideoUpload = !!currModel?.features?.includes(ModelFeatureEnum.video) // *** web app features *** const featuresData: FeaturesData = useMemo(() => { return { @@ -472,7 +473,7 @@ const Configuration: FC = () => { transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], }, enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled), - allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video], + allowed_file_types: modelConfig.file_upload?.allowed_file_types || [], allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`), allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'], number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3, @@ -861,6 +862,8 @@ const Configuration: FC = () => { isShowVisionConfig, visionConfig, setVisionConfig: handleSetVisionConfig, + isAllowVideoUpload, + isShowDocumentConfig, rerankSettingModalOpen, setRerankSettingModalOpen, }} diff --git a/web/app/components/app/create-from-dsl-modal/index.tsx b/web/app/components/app/create-from-dsl-modal/index.tsx index e238ce0e91..ce06b113bc 100644 --- a/web/app/components/app/create-from-dsl-modal/index.tsx +++ b/web/app/components/app/create-from-dsl-modal/index.tsx @@ -12,9 +12,13 @@ import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import { ToastContext } from '@/app/components/base/toast' import { - importApp, - importAppFromUrl, + importDSL, + importDSLConfirm, } from '@/service/apps' +import { + DSLImportMode, + DSLImportStatus, +} from '@/models/app' import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import AppsFull from '@/app/components/billing/apps-full-in-dialog' @@ -43,6 +47,9 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const [fileContent, setFileContent] = useState() const [currentTab, setCurrentTab] = useState(activeTab) const [dslUrlValue, setDslUrlValue] = useState(dslUrl) + const [showErrorModal, 
setShowErrorModal] = useState(false) + const [versions, setVersions] = useState<{ importedVersion: string; systemVersion: string }>() + const [importId, setImportId] = useState() const readFile = (file: File) => { const reader = new FileReader() @@ -66,6 +73,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS const isAppsFull = (enableBilling && plan.usage.buildApps >= plan.total.buildApps) const isCreatingRef = useRef(false) + const onCreate: MouseEventHandler = async () => { if (currentTab === CreateFromDSLModalTab.FROM_FILE && !currentFile) return @@ -75,25 +83,54 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS return isCreatingRef.current = true try { - let app + let response if (currentTab === CreateFromDSLModalTab.FROM_FILE) { - app = await importApp({ - data: fileContent || '', + response = await importDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: fileContent || '', }) } if (currentTab === CreateFromDSLModalTab.FROM_URL) { - app = await importAppFromUrl({ - url: dslUrlValue || '', + response = await importDSL({ + mode: DSLImportMode.YAML_URL, + yaml_url: dslUrlValue || '', }) } - if (onSuccess) - onSuccess() - if (onClose) - onClose() - notify({ type: 'success', message: t('app.newApp.appCreated') }) - localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - getRedirection(isCurrentWorkspaceEditor, app, push) + + if (!response) + return + + const { id, status, app_id, imported_dsl_version, current_dsl_version } = response + if (status === DSLImportStatus.COMPLETED || status === DSLImportStatus.COMPLETED_WITH_WARNINGS) { + if (onSuccess) + onSuccess() + if (onClose) + onClose() + + notify({ + type: status === DSLImportStatus.COMPLETED ? 'success' : 'warning', + message: t(status === DSLImportStatus.COMPLETED ? 'app.newApp.appCreated' : 'app.newApp.caution'), + children: status === DSLImportStatus.COMPLETED_WITH_WARNINGS && t('app.newApp.appCreateDSLWarning'), + }) + localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + getRedirection(isCurrentWorkspaceEditor, { id: app_id }, push) + } + else if (status === DSLImportStatus.PENDING) { + setVersions({ + importedVersion: imported_dsl_version ?? '', + systemVersion: current_dsl_version ?? 
'', + }) + if (onClose) + onClose() + setTimeout(() => { + setShowErrorModal(true) + }, 300) + setImportId(id) + } + else { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } } catch (e) { notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) @@ -101,6 +138,38 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS isCreatingRef.current = false } + const onDSLConfirm: MouseEventHandler = async () => { + try { + if (!importId) + return + const response = await importDSLConfirm({ + import_id: importId, + }) + + const { status, app_id } = response + + if (status === DSLImportStatus.COMPLETED) { + if (onSuccess) + onSuccess() + if (onClose) + onClose() + + notify({ + type: 'success', + message: t('app.newApp.appCreated'), + }) + localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + getRedirection(isCurrentWorkspaceEditor, { id: app_id }, push) + } + else if (status === DSLImportStatus.FAILED) { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } + } + catch (e) { + notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + } + } + const tabs = [ { key: CreateFromDSLModalTab.FROM_FILE, @@ -123,74 +192,96 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS }, [isAppsFull, currentTab, currentFile, dslUrlValue]) return ( - { }} - > -
- {t('app.importFromDSL')} -
onClose()} - > - + <> + { }} + > +
+ {t('app.importFromDSL')} +
onClose()} + > + +
-
-
- { - tabs.map(tab => ( -
setCurrentTab(tab.key)} - > - {tab.label} - { - currentTab === tab.key && ( -
- ) - } -
- )) - } -
-
- { - currentTab === CreateFromDSLModalTab.FROM_FILE && ( - - ) - } - { - currentTab === CreateFromDSLModalTab.FROM_URL && ( -
-
DSL URL
- setDslUrlValue(e.target.value)} +
+ { + tabs.map(tab => ( +
setCurrentTab(tab.key)} + > + {tab.label} + { + currentTab === tab.key && ( +
+ ) + } +
+ )) + } +
+
+ { + currentTab === CreateFromDSLModalTab.FROM_FILE && ( + -
- ) - } -
- {isAppsFull && ( -
- + ) + } + { + currentTab === CreateFromDSLModalTab.FROM_URL && ( +
+
DSL URL
+ setDslUrlValue(e.target.value)} + /> +
+ ) + }
- )} -
- - -
- + {isAppsFull && ( +
+ +
+ )} +
+ + +
+ + setShowErrorModal(false)} + className='w-[480px]' + > +
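+        {/* Shown when the imported DSL version differs from the current system version; confirming retries the import via onDSLConfirm */}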
+
{t('app.newApp.appCreateDSLErrorTitle')}
+
+
{t('app.newApp.appCreateDSLErrorPart1')}
+
{t('app.newApp.appCreateDSLErrorPart2')}
+
+
{t('app.newApp.appCreateDSLErrorPart3')}{versions?.importedVersion}
+
{t('app.newApp.appCreateDSLErrorPart4')}{versions?.systemVersion}
+
+
+
+ + +
+
+ ) } diff --git a/web/app/components/app/create-from-dsl-modal/uploader.tsx b/web/app/components/app/create-from-dsl-modal/uploader.tsx index fa5554f9cf..beb2b4b1a0 100644 --- a/web/app/components/app/create-from-dsl-modal/uploader.tsx +++ b/web/app/components/app/create-from-dsl-modal/uploader.tsx @@ -6,6 +6,7 @@ import { } from '@remixicon/react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' +import { formatFileSize } from '@/utils/format' import cn from '@/utils/classnames' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' import { ToastContext } from '@/app/components/base/toast' @@ -58,8 +59,13 @@ const Uploader: FC = ({ updateFile(files[0]) } const selectHandle = () => { - if (fileUploader.current) + const originalFile = file + if (fileUploader.current) { + fileUploader.current.value = '' fileUploader.current.click() + // If no file is selected, restore the original file + fileUploader.current.oncancel = () => updateFile(originalFile) + } } const removeFile = () => { if (fileUploader.current) @@ -96,7 +102,7 @@ const Uploader: FC = ({ />
{!file && ( -
+
@@ -108,17 +114,23 @@ const Uploader: FC = ({
)} {file && ( -
- -
- {file.name.replace(/(.yaml|.yml)$/, '')} - .yml +
+
+ +
+
+ {file.name} +
+ YAML + · + {formatFileSize(file.size)} +
- +
diff --git a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap index 7da09c4529..4ffcfa31e9 100644 --- a/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap +++ b/web/app/components/base/chat/__tests__/__snapshots__/utils.spec.ts.snap @@ -1804,8 +1804,85 @@ exports[`build chat item tree and get thread messages should get thread messages ] `; -exports[`build chat item tree and get thread messages should work with partial messages 1`] = ` +exports[`build chat item tree and get thread messages should work with partial messages 1 1`] = ` [ + { + "children": [ + { + "agent_thoughts": [ + { + "chain_id": null, + "created_at": 1726105799, + "files": [], + "id": "9730d587-9268-4683-9dd9-91a1cab9510b", + "message_id": "4c5d0841-1206-463e-95d8-71f812877658", + "observation": "", + "position": 1, + "thought": "I'll go with 112. Your turn!", + "tool": "", + "tool_input": "", + "tool_labels": {}, + }, + ], + "children": [], + "content": "I'll go with 112. Your turn!", + "conversationId": "dd6c9cfd-2656-48ec-bd51-2139c1790d80", + "feedbackDisabled": false, + "id": "4c5d0841-1206-463e-95d8-71f812877658", + "input": { + "inputs": {}, + "query": "99", + }, + "isAnswer": true, + "log": [ + { + "files": [], + "role": "user", + "text": "Let's play a game, I say a number , and you response me with another bigger, yet randomly number. I'll start first, 38", + }, + { + "files": [], + "role": "assistant", + "text": "Sure, I'll play! My number is 57. Your turn!", + }, + { + "files": [], + "role": "user", + "text": "58", + }, + { + "files": [], + "role": "assistant", + "text": "I choose 83. What's your next number?", + }, + { + "files": [], + "role": "user", + "text": "99", + }, + { + "files": [], + "role": "assistant", + "text": "I'll go with 112. 
Your turn!", + }, + ], + "message_files": [], + "more": { + "latency": "1.49", + "time": "09/11/2024 09:50 PM", + "tokens": 86, + }, + "parentMessageId": "question-4c5d0841-1206-463e-95d8-71f812877658", + "siblingIndex": 0, + "workflow_run_id": null, + }, + ], + "content": "99", + "id": "question-4c5d0841-1206-463e-95d8-71f812877658", + "isAnswer": false, + "message_files": [], + "parentMessageId": "73bbad14-d915-499d-87bf-0df14d40779d", + }, { "children": [ { @@ -2078,6 +2155,178 @@ exports[`build chat item tree and get thread messages should work with partial m ] `; +exports[`build chat item tree and get thread messages should work with partial messages 2 1`] = ` +[ + { + "children": [ + { + "children": [], + "content": "237.", + "id": "ebb73fe2-15de-46dd-aab5-75416d8448eb", + "isAnswer": true, + "parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb", + "siblingIndex": 0, + }, + ], + "content": "123", + "id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418", + }, + { + "children": [ + { + "children": [], + "content": "My number is 256.", + "id": "3553d508-3850-462e-8594-078539f940f9", + "isAnswer": true, + "parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9", + "siblingIndex": 1, + }, + ], + "content": "123", + "id": "question-3553d508-3850-462e-8594-078539f940f9", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418", + }, + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [ + { + "children": [], + "content": "My number is 3e (approximately 8.15).", + "id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "isAnswer": true, + "parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "siblingIndex": 0, + }, + ], + "content": "e", + "id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "isAnswer": false, + "parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + }, + ], + "content": "My number is 2π (approximately 6.28).", + "id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "isAnswer": true, + "parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "siblingIndex": 0, + }, + ], + "content": "π", + "id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "isAnswer": false, + "parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f", + }, + ], + "content": "My number is 145.", + "id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "isAnswer": true, + "parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "siblingIndex": 0, + }, + ], + "content": "78", + "id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "isAnswer": false, + "parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c", + }, + ], + "content": "My number is 7.89.", + "id": "3cded945-855a-4a24-aab7-43c7dd54664c", + "isAnswer": true, + "parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c", + "siblingIndex": 0, + }, + ], + "content": "3.11", + "id": "question-3cded945-855a-4a24-aab7-43c7dd54664c", + "isAnswer": false, + "parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + }, + ], + "content": "My number is 22.", + "id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "isAnswer": true, + "parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "siblingIndex": 0, + }, + ], + "content": "-5", + 
"id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "isAnswer": false, + "parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55", + }, + ], + "content": "My number is 4782.", + "id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "isAnswer": true, + "parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "siblingIndex": 0, + }, + ], + "content": "3306", + "id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "isAnswer": false, + "parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f", + }, + ], + "content": "My number is 2048.", + "id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "isAnswer": true, + "parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "siblingIndex": 0, + }, + ], + "content": "1024", + "id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "isAnswer": false, + "parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22", + }, + ], + "content": "My number is 259.", + "id": "507f9df9-1f06-4a57-bb38-f00228c42c22", + "isAnswer": true, + "parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22", + "siblingIndex": 2, + }, + ], + "content": "123", + "id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418", + }, +] +`; + exports[`build chat item tree and get thread messages should work with real world messages 1`] = ` [ { diff --git a/web/app/components/base/chat/__tests__/partialMessages.json b/web/app/components/base/chat/__tests__/partialMessages.json new file mode 100644 index 0000000000..916c6ad254 --- /dev/null +++ b/web/app/components/base/chat/__tests__/partialMessages.json @@ -0,0 +1,122 @@ +[ + { + "id": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb", + "content": "123", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418" + }, + { + "id": "ebb73fe2-15de-46dd-aab5-75416d8448eb", + "content": "237.", + "isAnswer": true, + "parentMessageId": "question-ebb73fe2-15de-46dd-aab5-75416d8448eb" + }, + { + "id": "question-3553d508-3850-462e-8594-078539f940f9", + "content": "123", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418" + }, + { + "id": "3553d508-3850-462e-8594-078539f940f9", + "content": "My number is 256.", + "isAnswer": true, + "parentMessageId": "question-3553d508-3850-462e-8594-078539f940f9" + }, + { + "id": "question-507f9df9-1f06-4a57-bb38-f00228c42c22", + "content": "123", + "isAnswer": false, + "parentMessageId": "57c989f9-3fa4-4dec-9ee5-c3568dd27418" + }, + { + "id": "507f9df9-1f06-4a57-bb38-f00228c42c22", + "content": "My number is 259.", + "isAnswer": true, + "parentMessageId": "question-507f9df9-1f06-4a57-bb38-f00228c42c22" + }, + { + "id": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "content": "1024", + "isAnswer": false, + "parentMessageId": "507f9df9-1f06-4a57-bb38-f00228c42c22" + }, + { + "id": "9e51a13b-7780-4565-98dc-f2d8c3b1758f", + "content": "My number is 2048.", + "isAnswer": true, + "parentMessageId": "question-9e51a13b-7780-4565-98dc-f2d8c3b1758f" + }, + { + "id": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "content": "3306", + "isAnswer": false, + "parentMessageId": "9e51a13b-7780-4565-98dc-f2d8c3b1758f" + }, + { + "id": "93bac05d-1470-4ac9-b090-fe21cd7c3d55", + "content": "My number is 4782.", + "isAnswer": true, + "parentMessageId": "question-93bac05d-1470-4ac9-b090-fe21cd7c3d55" + }, + { + "id": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "content": "-5", + "isAnswer": false, + "parentMessageId": "93bac05d-1470-4ac9-b090-fe21cd7c3d55" + }, + 
{ + "id": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7", + "content": "My number is 22.", + "isAnswer": true, + "parentMessageId": "question-a956de3d-ef95-4d90-84fe-f7a26ef28cd7" + }, + { + "id": "question-3cded945-855a-4a24-aab7-43c7dd54664c", + "content": "3.11", + "isAnswer": false, + "parentMessageId": "a956de3d-ef95-4d90-84fe-f7a26ef28cd7" + }, + { + "id": "3cded945-855a-4a24-aab7-43c7dd54664c", + "content": "My number is 7.89.", + "isAnswer": true, + "parentMessageId": "question-3cded945-855a-4a24-aab7-43c7dd54664c" + }, + { + "id": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "content": "78", + "isAnswer": false, + "parentMessageId": "3cded945-855a-4a24-aab7-43c7dd54664c" + }, + { + "id": "46a49bb9-0881-459e-8c6a-24d20ae48d2f", + "content": "My number is 145.", + "isAnswer": true, + "parentMessageId": "question-46a49bb9-0881-459e-8c6a-24d20ae48d2f" + }, + { + "id": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "content": "π", + "isAnswer": false, + "parentMessageId": "46a49bb9-0881-459e-8c6a-24d20ae48d2f" + }, + { + "id": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c", + "content": "My number is 2π (approximately 6.28).", + "isAnswer": true, + "parentMessageId": "question-5c56a2b3-f057-42a0-9b2c-52a35713cd8c" + }, + { + "id": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "content": "e", + "isAnswer": false, + "parentMessageId": "5c56a2b3-f057-42a0-9b2c-52a35713cd8c" + }, + { + "id": "9eac3bcc-8d3b-4e56-a12b-44c34cebc719", + "content": "My number is 3e (approximately 8.15).", + "isAnswer": true, + "parentMessageId": "question-9eac3bcc-8d3b-4e56-a12b-44c34cebc719" + } +] diff --git a/web/app/components/base/chat/__tests__/utils.spec.ts b/web/app/components/base/chat/__tests__/utils.spec.ts index 1dead1c949..0bff8a77a1 100644 --- a/web/app/components/base/chat/__tests__/utils.spec.ts +++ b/web/app/components/base/chat/__tests__/utils.spec.ts @@ -7,6 +7,7 @@ import mixedTestMessages from './mixedTestMessages.json' import multiRootNodesMessages from './multiRootNodesMessages.json' import multiRootNodesWithLegacyTestMessages from './multiRootNodesWithLegacyTestMessages.json' import realWorldMessages from './realWorldMessages.json' +import partialMessages from './partialMessages.json' function visitNode(tree: ChatItemInTree | ChatItemInTree[], path: string): ChatItemInTree { return get(tree, path) @@ -256,9 +257,15 @@ describe('build chat item tree and get thread messages', () => { expect(threadMessages6_2).toMatchSnapshot() }) - const partialMessages = (realWorldMessages as ChatItemInTree[]).slice(-10) - const tree7 = buildChatItemTree(partialMessages) - it('should work with partial messages', () => { + const partialMessages1 = (realWorldMessages as ChatItemInTree[]).slice(-10) + const tree7 = buildChatItemTree(partialMessages1) + it('should work with partial messages 1', () => { expect(tree7).toMatchSnapshot() }) + + const partialMessages2 = (partialMessages as ChatItemInTree[]) + const tree8 = buildChatItemTree(partialMessages2) + it('should work with partial messages 2', () => { + expect(tree8).toMatchSnapshot() + }) }) diff --git a/web/app/components/base/chat/chat/chat-input-area/index.tsx b/web/app/components/base/chat/chat/chat-input-area/index.tsx index 5169e65a59..eec636b478 100644 --- a/web/app/components/base/chat/chat/chat-input-area/index.tsx +++ b/web/app/components/base/chat/chat/chat-input-area/index.tsx @@ -102,21 +102,21 @@ const ChatInputArea = ({ setCurrentIndex(historyRef.current.length) handleSend() } - else if (e.key === 'ArrowUp' && !e.shiftKey && 
!e.nativeEvent.isComposing) { - // When the up key is pressed, output the previous element + else if (e.key === 'ArrowUp' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) { + // When the cmd + up key is pressed, output the previous element if (currentIndex > 0) { setCurrentIndex(currentIndex - 1) setQuery(historyRef.current[currentIndex - 1]) } } - else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing) { - // When the down key is pressed, output the next element + else if (e.key === 'ArrowDown' && !e.shiftKey && !e.nativeEvent.isComposing && e.metaKey) { + // When the cmd + down key is pressed, output the next element if (currentIndex < historyRef.current.length - 1) { setCurrentIndex(currentIndex + 1) setQuery(historyRef.current[currentIndex + 1]) } else if (currentIndex === historyRef.current.length - 1) { - // If it is the last element, clear the input box + // If it is the last element, clear the input box setCurrentIndex(historyRef.current.length) setQuery('') } diff --git a/web/app/components/base/chat/utils.ts b/web/app/components/base/chat/utils.ts index 61dfaecffc..326805c930 100644 --- a/web/app/components/base/chat/utils.ts +++ b/web/app/components/base/chat/utils.ts @@ -127,19 +127,16 @@ function buildChatItemTree(allMessages: IChatItem[]): ChatItemInTree[] { lastAppendedLegacyAnswer = answerNode } else { - if (!parentMessageId) + if ( + !parentMessageId + || !allMessages.some(item => item.id === parentMessageId) // parent message might not be fetched yet, in this case we will append the question to the root nodes + ) rootNodes.push(questionNode) else map[parentMessageId]?.children!.push(questionNode) } } - // If no messages have parentMessageId=null (indicating a root node), - // then we likely have a partial chat history. In this case, - // use the first available message as the root node. - if (rootNodes.length === 0 && allMessages.length > 0) - rootNodes.push(map[allMessages[0]!.id]!) - return rootNodes } diff --git a/web/app/components/base/divider/index.tsx b/web/app/components/base/divider/index.tsx index 85ce886199..4b351dea99 100644 --- a/web/app/components/base/divider/index.tsx +++ b/web/app/components/base/divider/index.tsx @@ -1,17 +1,31 @@ import type { CSSProperties, FC } from 'react' import React from 'react' -import s from './style.module.css' +import { type VariantProps, cva } from 'class-variance-authority' +import classNames from '@/utils/classnames' -type Props = { - type?: 'horizontal' | 'vertical' - // orientation?: 'left' | 'right' | 'center' +const dividerVariants = cva( + 'bg-divider-regular', + { + variants: { + type: { + horizontal: 'w-full h-[0.5px] my-2', + vertical: 'w-[1px] h-full mx-2', + }, + }, + defaultVariants: { + type: 'horizontal', + }, + }, +) + +type DividerProps = { className?: string style?: CSSProperties -} +} & VariantProps -const Divider: FC = ({ type = 'horizontal', className = '', style }) => { +const Divider: FC = ({ type, className = '', style }) => { return ( -
+
) } diff --git a/web/app/components/base/divider/style.module.css b/web/app/components/base/divider/style.module.css deleted file mode 100644 index 9cb2601b1f..0000000000 --- a/web/app/components/base/divider/style.module.css +++ /dev/null @@ -1,9 +0,0 @@ -.divider { - @apply bg-gray-200; -} -.horizontal { - @apply w-full h-[0.5px] my-2; -} -.vertical { - @apply w-[1px] h-full mx-2; -} diff --git a/web/app/components/base/ga/index.tsx b/web/app/components/base/ga/index.tsx index 219724113f..0015edbfca 100644 --- a/web/app/components/base/ga/index.tsx +++ b/web/app/components/base/ga/index.tsx @@ -47,6 +47,12 @@ gtag('config', '${gaIdMaps[gaType]}'); nonce={nonce!} > + {/* Cookie banner */} + ) diff --git a/web/app/components/base/icons/assets/vender/features/document.svg b/web/app/components/base/icons/assets/vender/features/document.svg new file mode 100644 index 0000000000..dca0e91a52 --- /dev/null +++ b/web/app/components/base/icons/assets/vender/features/document.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/features/Document.json b/web/app/components/base/icons/src/vender/features/Document.json new file mode 100644 index 0000000000..fdd08d5254 --- /dev/null +++ b/web/app/components/base/icons/src/vender/features/Document.json @@ -0,0 +1,23 @@ +{ + "icon": { + "type": "element", + "isRootNode": true, + "name": "svg", + "attributes": { + "xmlns": "http://www.w3.org/2000/svg", + "viewBox": "0 0 24 24", + "fill": "currentColor" + }, + "children": [ + { + "type": "element", + "name": "path", + "attributes": { + "d": "M20 22H4C3.44772 22 3 21.5523 3 21V3C3 2.44772 3.44772 2 4 2H20C20.5523 2 21 2.44772 21 3V21C21 21.5523 20.5523 22 20 22ZM7 6V10H11V6H7ZM7 12V14H17V12H7ZM7 16V18H17V16H7ZM13 7V9H17V7H13Z" + }, + "children": [] + } + ] + }, + "name": "Document" +} \ No newline at end of file diff --git a/web/app/components/base/icons/src/vender/features/Document.tsx b/web/app/components/base/icons/src/vender/features/Document.tsx new file mode 100644 index 0000000000..84bf3a2f10 --- /dev/null +++ b/web/app/components/base/icons/src/vender/features/Document.tsx @@ -0,0 +1,16 @@ +// GENERATE BY script +// DON NOT EDIT IT MANUALLY + +import * as React from 'react' +import data from './Document.json' +import IconBase from '@/app/components/base/icons/IconBase' +import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase' + +const Icon = React.forwardRef, Omit>(( + props, + ref, +) => ) + +Icon.displayName = 'Document' + +export default Icon diff --git a/web/app/components/base/icons/src/vender/features/index.ts b/web/app/components/base/icons/src/vender/features/index.ts index 2b8cb17c94..f246732226 100644 --- a/web/app/components/base/icons/src/vender/features/index.ts +++ b/web/app/components/base/icons/src/vender/features/index.ts @@ -7,3 +7,4 @@ export { default as Microphone01 } from './Microphone01' export { default as TextToAudio } from './TextToAudio' export { default as VirtualAssistant } from './VirtualAssistant' export { default as Vision } from './Vision' +export { default as Document } from './Document' diff --git a/web/app/components/base/toast/index.tsx b/web/app/components/base/toast/index.tsx index 3e13db5d7f..b1b9ffe8c4 100644 --- a/web/app/components/base/toast/index.tsx +++ b/web/app/components/base/toast/index.tsx @@ -3,16 +3,19 @@ import type { ReactNode } from 'react' import React, { useEffect, useState } from 'react' import { createRoot } from 'react-dom/client' import { - CheckCircleIcon, - 
-  ExclamationTriangleIcon,
-  InformationCircleIcon,
-  XCircleIcon,
-} from '@heroicons/react/20/solid'
+  RiAlertFill,
+  RiCheckboxCircleFill,
+  RiCloseLine,
+  RiErrorWarningFill,
+  RiInformation2Fill,
+} from '@remixicon/react'
 import { createContext, useContext } from 'use-context-selector'
+import ActionButton from '@/app/components/base/action-button'
 import classNames from '@/utils/classnames'
 
 export type IToastProps = {
   type?: 'success' | 'error' | 'warning' | 'info'
+  size?: 'md' | 'sm'
   duration?: number
   message: string
   children?: ReactNode
@@ -21,60 +24,55 @@
 type IToastContext = {
   notify: (props: IToastProps) => void
+  close: () => void
 }
 
 export const ToastContext = createContext<IToastContext>({} as IToastContext)
 export const useToastContext = () => useContext(ToastContext)
 const Toast = ({
   type = 'info',
+  size = 'md',
   message,
   children,
   className,
 }: IToastProps) => {
+  const { close } = useToastContext()
   // Sometimes the message is an array of React nodes; we don't handle that case.
   if (typeof message !== 'string')
     return null
 
   return
-
-
- {type === 'success' &&